[Mlir-commits] [mlir] [mlir][OpenMP] inscan reduction modifier and scan op mlir support (PR #114737)
Anchu Rajendran S
llvmlistbot at llvm.org
Wed Jan 15 11:18:27 PST 2025
https://github.com/anchuraj updated https://github.com/llvm/llvm-project/pull/114737
>From 59ac266b0251b7336059603a385688c11fde8187 Mon Sep 17 00:00:00 2001
From: Anchu Rajendran <asudhaku at amd.com>
Date: Sat, 2 Nov 2024 21:57:20 -0500
Subject: [PATCH 1/5] Changes for inscan reduction and scan op
---
.../mlir/Dialect/OpenMP/OpenMPClauses.td | 71 +++++
.../mlir/Dialect/OpenMP/OpenMPEnums.td | 21 ++
mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td | 42 ++-
.../Conversion/SCFToOpenMP/SCFToOpenMP.cpp | 1 +
mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp | 255 ++++++++++++------
mlir/test/Dialect/OpenMP/invalid.mlir | 79 ++++++
mlir/test/Dialect/OpenMP/ops.mlir | 23 ++
7 files changed, 400 insertions(+), 92 deletions(-)
diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
index 8af054be322a55..56ecc15dfc8799 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
@@ -334,6 +334,40 @@ class OpenMP_DoacrossClauseSkip<
def OpenMP_DoacrossClause : OpenMP_DoacrossClauseSkip<>;
+//===----------------------------------------------------------------------===//
+// V5.2: [5.4.7] `exclusive` clause
+//===----------------------------------------------------------------------===//
+
+class OpenMP_ExclusiveClauseSkip<
+ bit traits = false, bit arguments = false, bit assemblyFormat = false,
+ bit description = false, bit extraClassDeclaration = false
+ > : OpenMP_Clause<traits, arguments, assemblyFormat, description,
+ extraClassDeclaration> {
+ let arguments = (ins
+ Variadic<AnyType>:$exclusive_vars
+ );
+
+ let optAssemblyFormat = [{
+ `exclusive` `(` $exclusive_vars `:` type($exclusive_vars) `)`
+ }];
+
+ let extraClassDeclaration = [{
+ bool hasExclusiveVars() {
+ return getExclusiveVars().size()>0;
+ }
+ }];
+
+ let description = [{
+ The exclusive clause is used on a separating directive that separates a
+ structured block into two structured block sequences. If it
+ is specified, the input phase excludes the preceding structured block
+ sequence and instead includes the following structured block sequence,
+ while the scan phase includes the preceding structured block sequence.
+ }];
+}
+
+def OpenMP_ExclusiveClause : OpenMP_ExclusiveClauseSkip<>;
+
//===----------------------------------------------------------------------===//
// V5.2: [10.5.1] `filter` clause
//===----------------------------------------------------------------------===//
@@ -444,6 +478,40 @@ class OpenMP_HasDeviceAddrClauseSkip<
def OpenMP_HasDeviceAddrClause : OpenMP_HasDeviceAddrClauseSkip<>;
+//===----------------------------------------------------------------------===//
+// V5.2: [5.4.7] `inclusive` clause
+//===----------------------------------------------------------------------===//
+
+class OpenMP_InclusiveClauseSkip<
+ bit traits = false, bit arguments = false, bit assemblyFormat = false,
+ bit description = false, bit extraClassDeclaration = false
+ > : OpenMP_Clause<traits, arguments, assemblyFormat, description,
+ extraClassDeclaration> {
+ let arguments = (ins
+ Variadic<AnyType>:$inclusive_vars
+ );
+
+ let optAssemblyFormat = [{
+ `inclusive` `(` $inclusive_vars `:` type($inclusive_vars) `)`
+ }];
+
+ let extraClassDeclaration = [{
+ bool hasInclusiveVars() {
+ return getInclusiveVars().size()>0;
+ }
+ }];
+
+ let description = [{
+ The inclusive clause is used on a separating directive that separates a
+ structured block into two structured block sequences. If it is specified,
+ the input phase includes the preceding structured block sequence and the
+ scan phase includes the following structured block sequence.
+ }];
+}
+
+def OpenMP_InclusiveClause : OpenMP_InclusiveClauseSkip<>;
+
+
//===----------------------------------------------------------------------===//
// V5.2: [15.1.2] `hint` clause
//===----------------------------------------------------------------------===//
@@ -544,6 +612,7 @@ class OpenMP_InReductionClauseSkip<
];
let arguments = (ins
+ OptionalAttr<ReductionModifierAttr>:$in_reduction_mod,
Variadic<OpenMP_PointerLikeType>:$in_reduction_vars,
OptionalAttr<DenseBoolArrayAttr>:$in_reduction_byref,
OptionalAttr<SymbolRefArrayAttr>:$in_reduction_syms
@@ -1100,6 +1169,7 @@ class OpenMP_ReductionClauseSkip<
];
let arguments = (ins
+ OptionalAttr<ReductionModifierAttr>:$reduction_mod,
Variadic<OpenMP_PointerLikeType>:$reduction_vars,
OptionalAttr<DenseBoolArrayAttr>:$reduction_byref,
OptionalAttr<SymbolRefArrayAttr>:$reduction_syms
@@ -1230,6 +1300,7 @@ class OpenMP_TaskReductionClauseSkip<
];
let arguments = (ins
+ OptionalAttr<ReductionModifierAttr>:$task_reduction_mod,
Variadic<OpenMP_PointerLikeType>:$task_reduction_vars,
OptionalAttr<DenseBoolArrayAttr>:$task_reduction_byref,
OptionalAttr<SymbolRefArrayAttr>:$task_reduction_syms
diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPEnums.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPEnums.td
index 2091c0c76dff72..25e08aa726af40 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPEnums.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPEnums.td
@@ -179,6 +179,27 @@ def OrderModifier
def OrderModifierAttr : EnumAttr<OpenMP_Dialect, OrderModifier,
"order_mod">;
+//===----------------------------------------------------------------------===//
+// reduction_modifier enum.
+//===----------------------------------------------------------------------===//
+
+def ReductionModifierInScan : I32EnumAttrCase<"InScan", 0>;
+def ReductionModifierTask : I32EnumAttrCase<"Task", 1>;
+def ReductionModifierDefault : I32EnumAttrCase<"Default", 2>;
+
+def ReductionModifier : OpenMP_I32EnumAttr<
+ "ReductionModifier",
+ "reduction modifier", [
+ ReductionModifierInScan,
+ ReductionModifierTask,
+ ReductionModifierDefault
+ ]>;
+
+def ReductionModifierAttr : OpenMP_EnumAttr<ReductionModifier,
+ "reduction_modifier"> {
+ let assemblyFormat = "`(` $value `)`";
+}
+
//===----------------------------------------------------------------------===//
// sched_mod enum.
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
index c5b88904367086..6c62c83398b1a7 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
@@ -178,7 +178,7 @@ def ParallelOp : OpenMP_Op<"parallel", traits = [
let assemblyFormat = clausesAssemblyFormat # [{
custom<PrivateReductionRegion>($region, $private_vars, type($private_vars),
- $private_syms, $reduction_vars, type($reduction_vars), $reduction_byref,
+ $private_syms, $reduction_mod, $reduction_vars, type($reduction_vars), $reduction_byref,
$reduction_syms) attr-dict
}];
@@ -223,7 +223,7 @@ def TeamsOp : OpenMP_Op<"teams", traits = [
let assemblyFormat = clausesAssemblyFormat # [{
custom<PrivateReductionRegion>($region, $private_vars, type($private_vars),
- $private_syms, $reduction_vars, type($reduction_vars), $reduction_byref,
+ $private_syms, $reduction_mod, $reduction_vars, type($reduction_vars), $reduction_byref,
$reduction_syms) attr-dict
}];
@@ -282,7 +282,7 @@ def SectionsOp : OpenMP_Op<"sections", traits = [
let assemblyFormat = clausesAssemblyFormat # [{
custom<PrivateReductionRegion>($region, $private_vars, type($private_vars),
- $private_syms, $reduction_vars, type($reduction_vars), $reduction_byref,
+ $private_syms, $reduction_mod, $reduction_vars, type($reduction_vars), $reduction_byref,
$reduction_syms) attr-dict
}];
@@ -469,7 +469,7 @@ def LoopOp : OpenMP_Op<"loop", traits = [
let assemblyFormat = clausesAssemblyFormat # [{
custom<PrivateReductionRegion>($region, $private_vars, type($private_vars),
- $private_syms, $reduction_vars, type($reduction_vars), $reduction_byref,
+ $private_syms, $reduction_mod, $reduction_vars, type($reduction_vars), $reduction_byref,
$reduction_syms) attr-dict
}];
@@ -521,7 +521,7 @@ def WsloopOp : OpenMP_Op<"wsloop", traits = [
let assemblyFormat = clausesAssemblyFormat # [{
custom<PrivateReductionRegion>($region, $private_vars, type($private_vars),
- $private_syms, $reduction_vars, type($reduction_vars), $reduction_byref,
+ $private_syms, $reduction_mod, $reduction_vars, type($reduction_vars), $reduction_byref,
$reduction_syms) attr-dict
}];
@@ -575,7 +575,7 @@ def SimdOp : OpenMP_Op<"simd", traits = [
let assemblyFormat = clausesAssemblyFormat # [{
custom<PrivateReductionRegion>($region, $private_vars, type($private_vars),
- $private_syms, $reduction_vars, type($reduction_vars), $reduction_byref,
+ $private_syms, $reduction_mod, $reduction_vars, type($reduction_vars), $reduction_byref,
$reduction_syms) attr-dict
}];
@@ -702,7 +702,7 @@ def TaskOp
let assemblyFormat = clausesAssemblyFormat # [{
custom<InReductionPrivateRegion>(
- $region, $in_reduction_vars, type($in_reduction_vars),
+ $region, $in_reduction_mod, $in_reduction_vars, type($in_reduction_vars),
$in_reduction_byref, $in_reduction_syms, $private_vars,
type($private_vars), $private_syms) attr-dict
}];
@@ -780,9 +780,9 @@ def TaskloopOp : OpenMP_Op<"taskloop", traits = [
let assemblyFormat = clausesAssemblyFormat # [{
custom<InReductionPrivateReductionRegion>(
- $region, $in_reduction_vars, type($in_reduction_vars),
+ $region, $in_reduction_mod, $in_reduction_vars, type($in_reduction_vars),
$in_reduction_byref, $in_reduction_syms, $private_vars,
- type($private_vars), $private_syms, $reduction_vars,
+ type($private_vars), $private_syms, $reduction_mod, $reduction_vars,
type($reduction_vars), $reduction_byref, $reduction_syms) attr-dict
}];
@@ -827,7 +827,7 @@ def TaskgroupOp : OpenMP_Op<"taskgroup", traits = [
let assemblyFormat = clausesAssemblyFormat # [{
custom<TaskReductionRegion>(
- $region, $task_reduction_vars, type($task_reduction_vars),
+ $region, $task_reduction_mod, $task_reduction_vars, type($task_reduction_vars),
$task_reduction_byref, $task_reduction_syms) attr-dict
}];
@@ -1289,7 +1289,7 @@ def TargetOp : OpenMP_Op<"target", traits = [
let assemblyFormat = clausesAssemblyFormat # [{
custom<HostEvalInReductionMapPrivateRegion>(
- $region, $host_eval_vars, type($host_eval_vars), $in_reduction_vars,
+ $region, $host_eval_vars, type($host_eval_vars), $in_reduction_mod, $in_reduction_vars,
type($in_reduction_vars), $in_reduction_byref, $in_reduction_syms,
$map_vars, type($map_vars), $private_vars, type($private_vars),
$private_syms, $private_maps) attr-dict
@@ -1706,6 +1706,26 @@ def CancellationPointOp : OpenMP_Op<"cancellation_point", clauses = [
let hasVerifier = 1;
}
+def ScanOp : OpenMP_Op<"scan", [
+ AttrSizedOperandSegments, RecipeInterface, IsolatedFromAbove
+ ], clauses = [
+ OpenMP_InclusiveClause, OpenMP_ExclusiveClause]> {
+ let summary = "scan directive";
+ let description = [{
+ The scan directive allows specifying scan reductions. The scan directive
+ must be enclosed within a parent directive that carries a reduction
+ clause with the `InScan` modifier. The scan directive separates the code
+ in the region enclosed by the parent into an input phase and a scan
+ phase.
+ }] # clausesDescription;
+
+ let builders = [
+ OpBuilder<(ins CArg<"const ScanOperands &">:$clauses)>
+ ];
+
+ let hasVerifier = 1;
+}
+
//===----------------------------------------------------------------------===//
// 2.19.5.7 declare reduction Directive
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp b/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp
index aa241b91d758ca..233739e1d6d917 100644
--- a/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp
+++ b/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp
@@ -451,6 +451,7 @@ struct ParallelOpLowering : public OpRewritePattern<scf::ParallelOp> {
/* private_vars = */ ValueRange(),
/* private_syms = */ nullptr,
/* proc_bind_kind = */ omp::ClauseProcBindKindAttr{},
+ /* reduction_mod = */ nullptr,
/* reduction_vars = */ llvm::SmallVector<Value>{},
/* reduction_byref = */ DenseBoolArrayAttr{},
/* reduction_syms = */ ArrayAttr{});
diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
index 5a619254a5ee14..2bc14f5abfc695 100644
--- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
+++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
@@ -494,16 +494,22 @@ struct PrivateParseArgs {
DenseI64ArrayAttr *mapIndices = nullptr)
: vars(vars), types(types), syms(syms), mapIndices(mapIndices) {}
};
+
struct ReductionParseArgs {
+ ReductionModifierAttr &reductionMod;
SmallVectorImpl<OpAsmParser::UnresolvedOperand> &vars;
SmallVectorImpl<Type> &types;
DenseBoolArrayAttr &byref;
ArrayAttr &syms;
- ReductionParseArgs(SmallVectorImpl<OpAsmParser::UnresolvedOperand> &vars,
+ ReductionParseArgs(ReductionModifierAttr &redMod,
+ SmallVectorImpl<OpAsmParser::UnresolvedOperand> &vars,
SmallVectorImpl<Type> &types, DenseBoolArrayAttr &byref,
ArrayAttr &syms)
- : vars(vars), types(types), byref(byref), syms(syms) {}
+ : reductionMod(redMod), vars(vars), types(types), byref(byref),
+ syms(syms) {}
};
+
+// Specifies the arguments needed for the `reduction` clause.
struct AllRegionParseArgs {
std::optional<MapParseArgs> hostEvalArgs;
std::optional<ReductionParseArgs> inReductionArgs;
@@ -522,7 +528,8 @@ static ParseResult parseClauseWithRegionArgs(
SmallVectorImpl<Type> &types,
SmallVectorImpl<OpAsmParser::Argument> &regionPrivateArgs,
ArrayAttr *symbols = nullptr, DenseI64ArrayAttr *mapIndices = nullptr,
- DenseBoolArrayAttr *byref = nullptr) {
+ DenseBoolArrayAttr *byref = nullptr,
+ ReductionModifierAttr *reductionMod = nullptr) {
SmallVector<SymbolRefAttr> symbolVec;
SmallVector<int64_t> mapIndicesVec;
SmallVector<bool> isByRefVec;
@@ -531,6 +538,16 @@ static ParseResult parseClauseWithRegionArgs(
if (parser.parseLParen())
return failure();
+ StringRef enumStr;
+ if (succeeded(parser.parseOptionalKeyword("Id"))) {
+ if (parser.parseColon() || parser.parseKeyword(&enumStr) ||
+ parser.parseComma())
+ return failure();
+ std::optional<ReductionModifier> enumValue =
+ symbolizeReductionModifier(enumStr);
+ *reductionMod = ReductionModifierAttr::get(parser.getContext(), *enumValue);
+ }
+
if (parser.parseCommaSeparatedList([&]() {
if (byref)
isByRefVec.push_back(
@@ -635,16 +652,14 @@ static ParseResult parseBlockArgClause(
if (succeeded(parser.parseOptionalKeyword(keyword))) {
if (!reductionArgs)
return failure();
-
if (failed(parseClauseWithRegionArgs(
parser, reductionArgs->vars, reductionArgs->types, entryBlockArgs,
- &reductionArgs->syms, /*mapIndices=*/nullptr,
- &reductionArgs->byref)))
+ &reductionArgs->syms, /*mapIndices=*/nullptr, &reductionArgs->byref,
+ &(reductionArgs->reductionMod))))
return failure();
}
return success();
}
-
static ParseResult parseBlockArgRegion(OpAsmParser &parser, Region &region,
AllRegionParseArgs args) {
llvm::SmallVector<OpAsmParser::Argument> entryBlockArgs;
@@ -695,7 +710,7 @@ static ParseResult parseBlockArgRegion(OpAsmParser &parser, Region &region,
static ParseResult parseHostEvalInReductionMapPrivateRegion(
OpAsmParser &parser, Region &region,
SmallVectorImpl<OpAsmParser::UnresolvedOperand> &hostEvalVars,
- SmallVectorImpl<Type> &hostEvalTypes,
+ SmallVectorImpl<Type> &hostEvalTypes, ReductionModifierAttr &inReductionMod,
SmallVectorImpl<OpAsmParser::UnresolvedOperand> &inReductionVars,
SmallVectorImpl<Type> &inReductionTypes,
DenseBoolArrayAttr &inReductionByref, ArrayAttr &inReductionSyms,
@@ -706,8 +721,9 @@ static ParseResult parseHostEvalInReductionMapPrivateRegion(
DenseI64ArrayAttr &privateMaps) {
AllRegionParseArgs args;
args.hostEvalArgs.emplace(hostEvalVars, hostEvalTypes);
- args.inReductionArgs.emplace(inReductionVars, inReductionTypes,
- inReductionByref, inReductionSyms);
+ args.inReductionArgs.emplace(inReductionMod, inReductionVars,
+ inReductionTypes, inReductionByref,
+ inReductionSyms);
args.mapArgs.emplace(mapVars, mapTypes);
args.privateArgs.emplace(privateVars, privateTypes, privateSyms,
&privateMaps);
@@ -715,35 +731,38 @@ static ParseResult parseHostEvalInReductionMapPrivateRegion(
}
static ParseResult parseInReductionPrivateRegion(
- OpAsmParser &parser, Region &region,
+ OpAsmParser &parser, Region &region, ReductionModifierAttr &inReductionMod,
SmallVectorImpl<OpAsmParser::UnresolvedOperand> &inReductionVars,
SmallVectorImpl<Type> &inReductionTypes,
DenseBoolArrayAttr &inReductionByref, ArrayAttr &inReductionSyms,
llvm::SmallVectorImpl<OpAsmParser::UnresolvedOperand> &privateVars,
llvm::SmallVectorImpl<Type> &privateTypes, ArrayAttr &privateSyms) {
AllRegionParseArgs args;
- args.inReductionArgs.emplace(inReductionVars, inReductionTypes,
- inReductionByref, inReductionSyms);
+ args.inReductionArgs.emplace(inReductionMod, inReductionVars,
+ inReductionTypes, inReductionByref,
+ inReductionSyms);
args.privateArgs.emplace(privateVars, privateTypes, privateSyms);
return parseBlockArgRegion(parser, region, args);
}
static ParseResult parseInReductionPrivateReductionRegion(
- OpAsmParser &parser, Region &region,
+ OpAsmParser &parser, Region &region, ReductionModifierAttr &inReductionMod,
SmallVectorImpl<OpAsmParser::UnresolvedOperand> &inReductionVars,
SmallVectorImpl<Type> &inReductionTypes,
DenseBoolArrayAttr &inReductionByref, ArrayAttr &inReductionSyms,
llvm::SmallVectorImpl<OpAsmParser::UnresolvedOperand> &privateVars,
llvm::SmallVectorImpl<Type> &privateTypes, ArrayAttr &privateSyms,
+ ReductionModifierAttr &reductionMod,
SmallVectorImpl<OpAsmParser::UnresolvedOperand> &reductionVars,
SmallVectorImpl<Type> &reductionTypes, DenseBoolArrayAttr &reductionByref,
ArrayAttr &reductionSyms) {
AllRegionParseArgs args;
- args.inReductionArgs.emplace(inReductionVars, inReductionTypes,
- inReductionByref, inReductionSyms);
+ args.inReductionArgs.emplace(inReductionMod, inReductionVars,
+ inReductionTypes, inReductionByref,
+ inReductionSyms);
args.privateArgs.emplace(privateVars, privateTypes, privateSyms);
- args.reductionArgs.emplace(reductionVars, reductionTypes, reductionByref,
- reductionSyms);
+ args.reductionArgs.emplace(reductionMod, reductionVars, reductionTypes,
+ reductionByref, reductionSyms);
return parseBlockArgRegion(parser, region, args);
}
@@ -760,24 +779,27 @@ static ParseResult parsePrivateReductionRegion(
OpAsmParser &parser, Region &region,
llvm::SmallVectorImpl<OpAsmParser::UnresolvedOperand> &privateVars,
llvm::SmallVectorImpl<Type> &privateTypes, ArrayAttr &privateSyms,
+ ReductionModifierAttr &reductionMod,
SmallVectorImpl<OpAsmParser::UnresolvedOperand> &reductionVars,
SmallVectorImpl<Type> &reductionTypes, DenseBoolArrayAttr &reductionByref,
ArrayAttr &reductionSyms) {
AllRegionParseArgs args;
args.privateArgs.emplace(privateVars, privateTypes, privateSyms);
- args.reductionArgs.emplace(reductionVars, reductionTypes, reductionByref,
- reductionSyms);
+ args.reductionArgs.emplace(reductionMod, reductionVars, reductionTypes,
+ reductionByref, reductionSyms);
return parseBlockArgRegion(parser, region, args);
}
static ParseResult parseTaskReductionRegion(
OpAsmParser &parser, Region &region,
+ ReductionModifierAttr &taskReductionMod,
SmallVectorImpl<OpAsmParser::UnresolvedOperand> &taskReductionVars,
SmallVectorImpl<Type> &taskReductionTypes,
DenseBoolArrayAttr &taskReductionByref, ArrayAttr &taskReductionSyms) {
AllRegionParseArgs args;
- args.taskReductionArgs.emplace(taskReductionVars, taskReductionTypes,
- taskReductionByref, taskReductionSyms);
+ args.taskReductionArgs.emplace(taskReductionMod, taskReductionVars,
+ taskReductionTypes, taskReductionByref,
+ taskReductionSyms);
return parseBlockArgRegion(parser, region, args);
}
@@ -813,13 +835,15 @@ struct PrivatePrintArgs {
: vars(vars), types(types), syms(syms), mapIndices(mapIndices) {}
};
struct ReductionPrintArgs {
+ ReductionModifierAttr reductionMod;
ValueRange vars;
TypeRange types;
DenseBoolArrayAttr byref;
ArrayAttr syms;
- ReductionPrintArgs(ValueRange vars, TypeRange types, DenseBoolArrayAttr byref,
- ArrayAttr syms)
- : vars(vars), types(types), byref(byref), syms(syms) {}
+ ReductionPrintArgs(ReductionModifierAttr reductionMod, ValueRange vars,
+ TypeRange types, DenseBoolArrayAttr byref, ArrayAttr syms)
+ : reductionMod(reductionMod), vars(vars), types(types), byref(byref),
+ syms(syms) {}
};
struct AllRegionPrintArgs {
std::optional<MapPrintArgs> hostEvalArgs;
@@ -833,18 +857,21 @@ struct AllRegionPrintArgs {
};
} // namespace
-static void printClauseWithRegionArgs(OpAsmPrinter &p, MLIRContext *ctx,
- StringRef clauseName,
- ValueRange argsSubrange,
- ValueRange operands, TypeRange types,
- ArrayAttr symbols = nullptr,
- DenseI64ArrayAttr mapIndices = nullptr,
- DenseBoolArrayAttr byref = nullptr) {
+static void printClauseWithRegionArgs(
+ OpAsmPrinter &p, MLIRContext *ctx, StringRef clauseName,
+ ValueRange argsSubrange, ValueRange operands, TypeRange types,
+ ArrayAttr symbols = nullptr, DenseI64ArrayAttr mapIndices = nullptr,
+ DenseBoolArrayAttr byref = nullptr,
+ ReductionModifierAttr reductionMod = nullptr) {
if (argsSubrange.empty())
return;
p << clauseName << "(";
+ if (reductionMod) {
+ p << "Id: " << stringifyReductionModifier(reductionMod.getValue()) << ", ";
+ }
+
if (!symbols) {
llvm::SmallVector<Attribute> values(operands.size(), nullptr);
symbols = ArrayAttr::get(ctx, values);
@@ -902,10 +929,10 @@ printBlockArgClause(OpAsmPrinter &p, MLIRContext *ctx, StringRef clauseName,
ValueRange argsSubrange,
std::optional<ReductionPrintArgs> reductionArgs) {
if (reductionArgs)
- printClauseWithRegionArgs(p, ctx, clauseName, argsSubrange,
- reductionArgs->vars, reductionArgs->types,
- reductionArgs->syms, /*mapIndices=*/nullptr,
- reductionArgs->byref);
+ printClauseWithRegionArgs(
+ p, ctx, clauseName, argsSubrange, reductionArgs->vars,
+ reductionArgs->types, reductionArgs->syms, /*mapIndices=*/nullptr,
+ reductionArgs->byref, reductionArgs->reductionMod);
}
static void printBlockArgRegion(OpAsmPrinter &p, Operation *op, Region &region,
@@ -937,46 +964,53 @@ static void printBlockArgRegion(OpAsmPrinter &p, Operation *op, Region &region,
static void printHostEvalInReductionMapPrivateRegion(
OpAsmPrinter &p, Operation *op, Region &region, ValueRange hostEvalVars,
- TypeRange hostEvalTypes, ValueRange inReductionVars,
- TypeRange inReductionTypes, DenseBoolArrayAttr inReductionByref,
- ArrayAttr inReductionSyms, ValueRange mapVars, TypeRange mapTypes,
- ValueRange privateVars, TypeRange privateTypes, ArrayAttr privateSyms,
+ TypeRange hostEvalTypes, ReductionModifierAttr inReductionMod,
+ ValueRange inReductionVars, TypeRange inReductionTypes,
+ DenseBoolArrayAttr inReductionByref, ArrayAttr inReductionSyms,
+ ValueRange mapVars, TypeRange mapTypes, ValueRange privateVars,
+ TypeRange privateTypes, ArrayAttr privateSyms,
DenseI64ArrayAttr privateMaps) {
AllRegionPrintArgs args;
args.hostEvalArgs.emplace(hostEvalVars, hostEvalTypes);
- args.inReductionArgs.emplace(inReductionVars, inReductionTypes,
- inReductionByref, inReductionSyms);
+ args.inReductionArgs.emplace(inReductionMod, inReductionVars,
+ inReductionTypes, inReductionByref,
+ inReductionSyms);
args.mapArgs.emplace(mapVars, mapTypes);
args.privateArgs.emplace(privateVars, privateTypes, privateSyms, privateMaps);
printBlockArgRegion(p, op, region, args);
}
static void printInReductionPrivateRegion(
- OpAsmPrinter &p, Operation *op, Region &region, ValueRange inReductionVars,
+ OpAsmPrinter &p, Operation *op, Region &region,
+ ReductionModifierAttr inReductionMod, ValueRange inReductionVars,
TypeRange inReductionTypes, DenseBoolArrayAttr inReductionByref,
ArrayAttr inReductionSyms, ValueRange privateVars, TypeRange privateTypes,
ArrayAttr privateSyms) {
AllRegionPrintArgs args;
- args.inReductionArgs.emplace(inReductionVars, inReductionTypes,
- inReductionByref, inReductionSyms);
+ args.inReductionArgs.emplace(inReductionMod, inReductionVars,
+ inReductionTypes, inReductionByref,
+ inReductionSyms);
args.privateArgs.emplace(privateVars, privateTypes, privateSyms,
/*mapIndices=*/nullptr);
printBlockArgRegion(p, op, region, args);
}
static void printInReductionPrivateReductionRegion(
- OpAsmPrinter &p, Operation *op, Region &region, ValueRange inReductionVars,
+ OpAsmPrinter &p, Operation *op, Region &region,
+ ReductionModifierAttr inReductionMod, ValueRange inReductionVars,
TypeRange inReductionTypes, DenseBoolArrayAttr inReductionByref,
ArrayAttr inReductionSyms, ValueRange privateVars, TypeRange privateTypes,
- ArrayAttr privateSyms, ValueRange reductionVars, TypeRange reductionTypes,
+ ArrayAttr privateSyms, ReductionModifierAttr reductionMod,
+ ValueRange reductionVars, TypeRange reductionTypes,
DenseBoolArrayAttr reductionByref, ArrayAttr reductionSyms) {
AllRegionPrintArgs args;
- args.inReductionArgs.emplace(inReductionVars, inReductionTypes,
- inReductionByref, inReductionSyms);
+ args.inReductionArgs.emplace(inReductionMod, inReductionVars,
+ inReductionTypes, inReductionByref,
+ inReductionSyms);
args.privateArgs.emplace(privateVars, privateTypes, privateSyms,
/*mapIndices=*/nullptr);
- args.reductionArgs.emplace(reductionVars, reductionTypes, reductionByref,
- reductionSyms);
+ args.reductionArgs.emplace(reductionMod, reductionVars, reductionTypes,
+ reductionByref, reductionSyms);
printBlockArgRegion(p, op, region, args);
}
@@ -991,26 +1025,29 @@ static void printPrivateRegion(OpAsmPrinter &p, Operation *op, Region &region,
static void printPrivateReductionRegion(
OpAsmPrinter &p, Operation *op, Region &region, ValueRange privateVars,
- TypeRange privateTypes, ArrayAttr privateSyms, ValueRange reductionVars,
+ TypeRange privateTypes, ArrayAttr privateSyms,
+ ReductionModifierAttr reductionMod, ValueRange reductionVars,
TypeRange reductionTypes, DenseBoolArrayAttr reductionByref,
ArrayAttr reductionSyms) {
AllRegionPrintArgs args;
args.privateArgs.emplace(privateVars, privateTypes, privateSyms,
/*mapIndices=*/nullptr);
- args.reductionArgs.emplace(reductionVars, reductionTypes, reductionByref,
- reductionSyms);
+ args.reductionArgs.emplace(reductionMod, reductionVars, reductionTypes,
+ reductionByref, reductionSyms);
printBlockArgRegion(p, op, region, args);
}
static void printTaskReductionRegion(OpAsmPrinter &p, Operation *op,
Region &region,
+ ReductionModifierAttr taskReductionMod,
ValueRange taskReductionVars,
TypeRange taskReductionTypes,
DenseBoolArrayAttr taskReductionByref,
ArrayAttr taskReductionSyms) {
AllRegionPrintArgs args;
- args.taskReductionArgs.emplace(taskReductionVars, taskReductionTypes,
- taskReductionByref, taskReductionSyms);
+ args.taskReductionArgs.emplace(taskReductionMod, taskReductionVars,
+ taskReductionTypes, taskReductionByref,
+ taskReductionSyms);
printBlockArgRegion(p, op, region, args);
}
@@ -1727,6 +1764,7 @@ void TargetOp::build(OpBuilder &builder, OperationState &state,
clauses.bare, makeArrayAttr(ctx, clauses.dependKinds),
clauses.dependVars, clauses.device, clauses.hasDeviceAddrVars,
clauses.hostEvalVars, clauses.ifExpr,
+ /*in_reduction_mod=*/nullptr,
/*in_reduction_vars=*/{}, /*in_reduction_byref=*/nullptr,
/*in_reduction_syms=*/nullptr, clauses.isDevicePtrVars,
clauses.mapVars, clauses.nowait, clauses.privateVars,
@@ -1942,7 +1980,7 @@ void ParallelOp::build(OpBuilder &builder, OperationState &state,
/*allocator_vars=*/ValueRange(), /*if_expr=*/nullptr,
/*num_threads=*/nullptr, /*private_vars=*/ValueRange(),
/*private_syms=*/nullptr, /*proc_bind_kind=*/nullptr,
- /*reduction_vars=*/ValueRange(),
+ /*reduction_mod=*/nullptr, /*reduction_vars=*/ValueRange(),
/*reduction_byref=*/nullptr, /*reduction_syms=*/nullptr);
state.addAttributes(attributes);
}
@@ -1953,7 +1991,8 @@ void ParallelOp::build(OpBuilder &builder, OperationState &state,
ParallelOp::build(builder, state, clauses.allocateVars, clauses.allocatorVars,
clauses.ifExpr, clauses.numThreads, clauses.privateVars,
makeArrayAttr(ctx, clauses.privateSyms),
- clauses.procBindKind, clauses.reductionVars,
+ clauses.procBindKind, clauses.reductionMod,
+ clauses.reductionVars,
makeDenseBoolArrayAttr(ctx, clauses.reductionByref),
makeArrayAttr(ctx, clauses.reductionSyms));
}
@@ -2052,12 +2091,13 @@ void TeamsOp::build(OpBuilder &builder, OperationState &state,
const TeamsOperands &clauses) {
MLIRContext *ctx = builder.getContext();
// TODO Store clauses in op: privateVars, privateSyms.
- TeamsOp::build(
- builder, state, clauses.allocateVars, clauses.allocatorVars,
- clauses.ifExpr, clauses.numTeamsLower, clauses.numTeamsUpper,
- /*private_vars=*/{}, /*private_syms=*/nullptr, clauses.reductionVars,
- makeDenseBoolArrayAttr(ctx, clauses.reductionByref),
- makeArrayAttr(ctx, clauses.reductionSyms), clauses.threadLimit);
+ TeamsOp::build(builder, state, clauses.allocateVars, clauses.allocatorVars,
+ clauses.ifExpr, clauses.numTeamsLower, clauses.numTeamsUpper,
+ /*private_vars=*/{}, /*private_syms=*/nullptr,
+ clauses.reductionMod, clauses.reductionVars,
+ makeDenseBoolArrayAttr(ctx, clauses.reductionByref),
+ makeArrayAttr(ctx, clauses.reductionSyms),
+ clauses.threadLimit);
}
LogicalResult TeamsOp::verify() {
@@ -2114,7 +2154,8 @@ void SectionsOp::build(OpBuilder &builder, OperationState &state,
// TODO Store clauses in op: privateVars, privateSyms.
SectionsOp::build(builder, state, clauses.allocateVars, clauses.allocatorVars,
clauses.nowait, /*private_vars=*/{},
- /*private_syms=*/nullptr, clauses.reductionVars,
+ /*private_syms=*/nullptr, clauses.reductionMod,
+ clauses.reductionVars,
makeDenseBoolArrayAttr(ctx, clauses.reductionByref),
makeArrayAttr(ctx, clauses.reductionSyms));
}
@@ -2221,7 +2262,7 @@ void LoopOp::build(OpBuilder &builder, OperationState &state,
LoopOp::build(builder, state, clauses.bindKind, clauses.privateVars,
makeArrayAttr(ctx, clauses.privateSyms), clauses.order,
- clauses.orderMod, clauses.reductionVars,
+ clauses.orderMod, clauses.reductionMod, clauses.reductionVars,
makeDenseBoolArrayAttr(ctx, clauses.reductionByref),
makeArrayAttr(ctx, clauses.reductionSyms));
}
@@ -2249,7 +2290,7 @@ void WsloopOp::build(OpBuilder &builder, OperationState &state,
/*linear_vars=*/ValueRange(), /*linear_step_vars=*/ValueRange(),
/*nowait=*/false, /*order=*/nullptr, /*order_mod=*/nullptr,
/*ordered=*/nullptr, /*private_vars=*/{}, /*private_syms=*/nullptr,
- /*reduction_vars=*/ValueRange(), /*reduction_byref=*/nullptr,
+ nullptr, /*reduction_vars=*/ValueRange(), /*reduction_byref=*/nullptr,
/*reduction_syms=*/nullptr, /*schedule_kind=*/nullptr,
/*schedule_chunk=*/nullptr, /*schedule_mod=*/nullptr,
/*schedule_simd=*/false);
@@ -2261,15 +2302,16 @@ void WsloopOp::build(OpBuilder &builder, OperationState &state,
MLIRContext *ctx = builder.getContext();
// TODO: Store clauses in op: allocateVars, allocatorVars, privateVars,
// privateSyms.
- WsloopOp::build(
- builder, state,
- /*allocate_vars=*/{}, /*allocator_vars=*/{}, clauses.linearVars,
- clauses.linearStepVars, clauses.nowait, clauses.order, clauses.orderMod,
- clauses.ordered, clauses.privateVars,
- makeArrayAttr(ctx, clauses.privateSyms), clauses.reductionVars,
- makeDenseBoolArrayAttr(ctx, clauses.reductionByref),
- makeArrayAttr(ctx, clauses.reductionSyms), clauses.scheduleKind,
- clauses.scheduleChunk, clauses.scheduleMod, clauses.scheduleSimd);
+ WsloopOp::build(builder, state,
+ /*allocate_vars=*/{}, /*allocator_vars=*/{},
+ clauses.linearVars, clauses.linearStepVars, clauses.nowait,
+ clauses.order, clauses.orderMod, clauses.ordered,
+ clauses.privateVars, makeArrayAttr(ctx, clauses.privateSyms),
+ clauses.reductionMod, clauses.reductionVars,
+ makeDenseBoolArrayAttr(ctx, clauses.reductionByref),
+ makeArrayAttr(ctx, clauses.reductionSyms),
+ clauses.scheduleKind, clauses.scheduleChunk,
+ clauses.scheduleMod, clauses.scheduleSimd);
}
LogicalResult WsloopOp::verify() {
@@ -2316,7 +2358,7 @@ void SimdOp::build(OpBuilder &builder, OperationState &state,
/*linear_vars=*/{}, /*linear_step_vars=*/{},
clauses.nontemporalVars, clauses.order, clauses.orderMod,
clauses.privateVars, makeArrayAttr(ctx, clauses.privateSyms),
- clauses.reductionVars,
+ clauses.reductionMod, clauses.reductionVars,
makeDenseBoolArrayAttr(ctx, clauses.reductionByref),
makeArrayAttr(ctx, clauses.reductionSyms), clauses.safelen,
clauses.simdlen);
@@ -2497,7 +2539,8 @@ void TaskOp::build(OpBuilder &builder, OperationState &state,
MLIRContext *ctx = builder.getContext();
TaskOp::build(builder, state, clauses.allocateVars, clauses.allocatorVars,
makeArrayAttr(ctx, clauses.dependKinds), clauses.dependVars,
- clauses.final, clauses.ifExpr, clauses.inReductionVars,
+ clauses.final, clauses.ifExpr, clauses.inReductionMod,
+ clauses.inReductionVars,
makeDenseBoolArrayAttr(ctx, clauses.inReductionByref),
makeArrayAttr(ctx, clauses.inReductionSyms), clauses.mergeable,
clauses.priority, /*private_vars=*/clauses.privateVars,
@@ -2523,7 +2566,8 @@ void TaskgroupOp::build(OpBuilder &builder, OperationState &state,
const TaskgroupOperands &clauses) {
MLIRContext *ctx = builder.getContext();
TaskgroupOp::build(builder, state, clauses.allocateVars,
- clauses.allocatorVars, clauses.taskReductionVars,
+ clauses.allocatorVars, clauses.taskReductionMod,
+ clauses.taskReductionVars,
makeDenseBoolArrayAttr(ctx, clauses.taskReductionByref),
makeArrayAttr(ctx, clauses.taskReductionSyms));
}
@@ -2544,11 +2588,12 @@ void TaskloopOp::build(OpBuilder &builder, OperationState &state,
// TODO Store clauses in op: privateVars, privateSyms.
TaskloopOp::build(
builder, state, clauses.allocateVars, clauses.allocatorVars,
- clauses.final, clauses.grainsize, clauses.ifExpr, clauses.inReductionVars,
+ clauses.final, clauses.grainsize, clauses.ifExpr, clauses.inReductionMod,
+ clauses.inReductionVars,
makeDenseBoolArrayAttr(ctx, clauses.inReductionByref),
makeArrayAttr(ctx, clauses.inReductionSyms), clauses.mergeable,
clauses.nogroup, clauses.numTasks, clauses.priority, /*private_vars=*/{},
- /*private_syms=*/nullptr, clauses.reductionVars,
+ /*private_syms=*/nullptr, clauses.reductionMod, clauses.reductionVars,
makeDenseBoolArrayAttr(ctx, clauses.reductionByref),
makeArrayAttr(ctx, clauses.reductionSyms), clauses.untied);
}
@@ -3125,6 +3170,54 @@ void MaskedOp::build(OpBuilder &builder, OperationState &state,
MaskedOp::build(builder, state, clauses.filteredThreadId);
}
+//===----------------------------------------------------------------------===//
+// Spec 5.2: Scan construct (5.6)
+//===----------------------------------------------------------------------===//
+
+void ScanOp::build(OpBuilder &builder, OperationState &state,
+ const ScanOperands &clauses) {
+ ScanOp::build(builder, state, clauses.inclusiveVars, clauses.exclusiveVars);
+}
+
+LogicalResult ScanOp::verify() {
+ if (hasExclusiveVars() && hasInclusiveVars()) {
+ return emitError(
+ "Exactly one of EXCLUSIVE or INCLUSIVE clause is expected");
+ }
+ const OperandRange &scanVars =
+ hasExclusiveVars() ? getExclusiveVars() : getInclusiveVars();
+ auto verifyScanVarsInReduction = [&scanVars](OperandRange reductionVars) {
+ for (const auto &it : scanVars)
+ if (!llvm::is_contained(reductionVars, it))
+ return false;
+ return true;
+ };
+ if (mlir::omp::WsloopOp parentOp =
+ (*this)->getParentOfType<mlir::omp::WsloopOp>()) {
+ if (parentOp.getReductionModAttr() &&
+ parentOp.getReductionModAttr().getValue() ==
+ mlir::omp::ReductionModifier::InScan) {
+ if (!verifyScanVarsInReduction(parentOp.getReductionVars())) {
+ return emitError(
+ "List item should appear in REDUCTION clause of the parent");
+ }
+ return success();
+ }
+ } else if (mlir::omp::SimdOp parentOp =
+ (*this)->getParentOfType<mlir::omp::SimdOp>()) {
+ if (parentOp.getReductionModAttr().getValue() ==
+ mlir::omp::ReductionModifier::InScan) {
+ if (!verifyScanVarsInReduction(parentOp.getReductionVars())) {
+ return emitError(
+ "List item should appear in REDUCTION clause of the parent");
+ }
+ return success();
+ }
+ }
+ return emitError("Scan Operation should be enclosed within a parent "
+ "WORSKSHARING LOOP or SIMD with INSCAN reduction modifier");
+}
+
#define GET_ATTRDEF_CLASSES
#include "mlir/Dialect/OpenMP/OpenMPOpsAttributes.cpp.inc"
diff --git a/mlir/test/Dialect/OpenMP/invalid.mlir b/mlir/test/Dialect/OpenMP/invalid.mlir
index c611614265592c..5215e363d43e66 100644
--- a/mlir/test/Dialect/OpenMP/invalid.mlir
+++ b/mlir/test/Dialect/OpenMP/invalid.mlir
@@ -1825,6 +1825,85 @@ func.func @omp_cancellationpoint2() {
// -----
+omp.declare_reduction @add_f32 : f32
+init {
+ ^bb0(%arg: f32):
+ %0 = arith.constant 0.0 : f32
+ omp.yield (%0 : f32)
+}
+combiner {
+ ^bb1(%arg0: f32, %arg1: f32):
+ %1 = arith.addf %arg0, %arg1 : f32
+ omp.yield (%1 : f32)
+}
+
+func.func @scan_test_1(%lb: i32, %ub: i32, %step: i32) {
+ %test1f32 = "test.f32"() : () -> (!llvm.ptr)
+ %test2f32 = "test.f32"() : () -> (!llvm.ptr)
+ omp.wsloop reduction(Id:InScan, @add_f32 %test1f32 -> %arg1 : !llvm.ptr) {
+ omp.loop_nest (%i, %j) : i32 = (%lb, %ub) to (%ub, %lb) step (%step, %step) {
+ // expected-error @below {{List item should appear in REDUCTION clause of the parent}}
+ omp.scan inclusive(%test2f32 : !llvm.ptr)
+ omp.yield
+ }
+ }
+ return
+}
+
+// -----
+
+omp.declare_reduction @add_f32 : f32
+init {
+ ^bb0(%arg: f32):
+ %0 = arith.constant 0.0 : f32
+ omp.yield (%0 : f32)
+}
+combiner {
+ ^bb1(%arg0: f32, %arg1: f32):
+ %1 = arith.addf %arg0, %arg1 : f32
+ omp.yield (%1 : f32)
+}
+
+func.func @scan_test_2(%lb: i32, %ub: i32, %step: i32) {
+ %test1f32 = "test.f32"() : () -> (!llvm.ptr)
+ omp.wsloop reduction(Id:InScan, @add_f32 %test1f32 -> %arg1 : !llvm.ptr) {
+ omp.loop_nest (%i, %j) : i32 = (%lb, %ub) to (%ub, %lb) step (%step, %step) {
+ // expected-error @below {{Exactly one of EXCLUSIVE or INCLUSIVE clause is expected}}
+ omp.scan inclusive(%test1f32 : !llvm.ptr) exclusive(%test1f32: !llvm.ptr)
+ omp.yield
+ }
+ }
+ return
+}
+
+// -----
+
+omp.declare_reduction @add_f32 : f32
+init {
+ ^bb0(%arg: f32):
+ %0 = arith.constant 0.0 : f32
+ omp.yield (%0 : f32)
+}
+combiner {
+ ^bb1(%arg0: f32, %arg1: f32):
+ %1 = arith.addf %arg0, %arg1 : f32
+ omp.yield (%1 : f32)
+}
+
+func.func @scan_test_2(%lb: i32, %ub: i32, %step: i32) {
+ %test1f32 = "test.f32"() : () -> (!llvm.ptr)
+ omp.taskloop reduction(Id:InScan, @add_f32 %test1f32 -> %arg1 : !llvm.ptr) {
+ omp.loop_nest (%i, %j) : i32 = (%lb, %ub) to (%ub, %lb) step (%step, %step) {
+ // expected-error @below {{Scan Operation should be enclosed within a parent WORSKSHARING LOOP or SIMD with INSCAN reduction modifier}}
+ omp.scan inclusive(%test1f32 : !llvm.ptr)
+ omp.yield
+ }
+ }
+ return
+}
+
+// -----
+
func.func @taskloop(%lb: i32, %ub: i32, %step: i32) {
%testmemref = "test.memref"() : () -> (memref<i32>)
// expected-error @below {{expected equal sizes for allocate and allocator variables}}
diff --git a/mlir/test/Dialect/OpenMP/ops.mlir b/mlir/test/Dialect/OpenMP/ops.mlir
index b1901c333ade8d..25610d030d3f1c 100644
--- a/mlir/test/Dialect/OpenMP/ops.mlir
+++ b/mlir/test/Dialect/OpenMP/ops.mlir
@@ -900,6 +900,29 @@ func.func @wsloop_reduction(%lb : index, %ub : index, %step : index) {
return
}
+// CHECK-LABEL: func @wsloop_inscan_reduction
+func.func @wsloop_inscan_reduction(%lb : index, %ub : index, %step : index) {
+ %c1 = arith.constant 1 : i32
+ %0 = llvm.alloca %c1 x i32 : (i32) -> !llvm.ptr
+ // CHECK: reduction(Id: InScan, @add_f32 %{{.+}} -> %[[PRV:.+]] : !llvm.ptr)
+ omp.wsloop reduction(Id:InScan, @add_f32 %0 -> %prv : !llvm.ptr) {
+ omp.loop_nest (%iv) : index = (%lb) to (%ub) step (%step) {
+ // CHECK: omp.scan inclusive(%{{.*}} : !llvm.ptr)
+ omp.scan inclusive(%0 : !llvm.ptr)
+ omp.yield
+ }
+ }
+ // CHECK: reduction(Id: InScan, @add_f32 %{{.+}} -> %[[PRV:.+]] : !llvm.ptr)
+ omp.wsloop reduction(Id:InScan, @add_f32 %0 -> %prv : !llvm.ptr) {
+ omp.loop_nest (%iv) : index = (%lb) to (%ub) step (%step) {
+ // CHECK: omp.scan exclusive(%{{.*}} : !llvm.ptr)
+ omp.scan exclusive(%0 : !llvm.ptr)
+ omp.yield
+ }
+ }
+ return
+}
+
// CHECK-LABEL: func @wsloop_reduction_byref
func.func @wsloop_reduction_byref(%lb : index, %ub : index, %step : index) {
%c1 = arith.constant 1 : i32
>From 1be9aa4ef98a26d67cf97a2c0c5c849bbec388b4 Mon Sep 17 00:00:00 2001
From: Anchu Rajendran <asudhaku at amd.com>
Date: Mon, 18 Nov 2024 16:43:45 -0600
Subject: [PATCH 2/5] R2: Addressing a few review comments
---
mlir/include/mlir/Dialect/OpenMP/OpenMPEnums.td | 10 +++++-----
mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td | 2 +-
mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp | 8 +++++---
mlir/test/Dialect/OpenMP/invalid.mlir | 2 +-
4 files changed, 12 insertions(+), 10 deletions(-)
diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPEnums.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPEnums.td
index 25e08aa726af40..129bd03f809c49 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPEnums.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPEnums.td
@@ -183,16 +183,16 @@ def OrderModifierAttr : EnumAttr<OpenMP_Dialect, OrderModifier,
// reduction_modifier enum.
//===----------------------------------------------------------------------===//
-def ReductionModifierInScan : I32EnumAttrCase<"InScan", 0>;
-def ReductionModifierTask : I32EnumAttrCase<"Task", 1>;
-def ReductionModifierDefault : I32EnumAttrCase<"Default", 2>;
+def ReductionModifierDefault : I32EnumAttrCase<"Default", 0>;
+def ReductionModifierInScan : I32EnumAttrCase<"InScan", 1>;
+def ReductionModifierTask : I32EnumAttrCase<"Task", 2>;
def ReductionModifier : OpenMP_I32EnumAttr<
"ReductionModifier",
"reduction modifier", [
+ ReductionModifierDefault,
ReductionModifierInScan,
- ReductionModifierTask,
- ReductionModifierDefault
+ ReductionModifierTask
]>;
def ReductionModifierAttr : OpenMP_EnumAttr<ReductionModifier,
diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
index 6c62c83398b1a7..cc375a0c021e81 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
@@ -1713,7 +1713,7 @@ def ScanOp : OpenMP_Op<"scan", [
let summary = "scan directive";
let description = [{
The scan directive allows to specify scan reduction. Scan directive
- should be enclosed with in a parent directive along with which , a
+ should be enclosed with in a parent directive along with which, a
reduction clause with `InScan` modifier must be specified. Scan directive
allows to separate code blocks to input phase and scan phase in the region
enclosed by the parent.
diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
index 2bc14f5abfc695..08df20fbb349cf 100644
--- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
+++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
@@ -2290,7 +2290,8 @@ void WsloopOp::build(OpBuilder &builder, OperationState &state,
/*linear_vars=*/ValueRange(), /*linear_step_vars=*/ValueRange(),
/*nowait=*/false, /*order=*/nullptr, /*order_mod=*/nullptr,
/*ordered=*/nullptr, /*private_vars=*/{}, /*private_syms=*/nullptr,
- nullptr, /*reduction_vars=*/ValueRange(), /*reduction_byref=*/nullptr,
+ /*reduction_mod=*/nullptr, /*reduction_vars=*/ValueRange(),
+ /*reduction_byref=*/nullptr,
/*reduction_syms=*/nullptr, /*schedule_kind=*/nullptr,
/*schedule_chunk=*/nullptr, /*schedule_mod=*/nullptr,
/*schedule_simd=*/false);
@@ -3214,8 +3215,9 @@ LogicalResult ScanOp::verify() {
return success();
}
}
- return emitError("Scan Operation should be enclosed within a parent "
- "WORSKSHARING LOOP or SIMD with INSCAN reduction modifier");
+ return emitError("SCAN directive needs to be enclosed within a parent "
+ "worksharing loop construct or SIMD construct with INSCAN "
+ "reduction modifier");
}
#define GET_ATTRDEF_CLASSES
diff --git a/mlir/test/Dialect/OpenMP/invalid.mlir b/mlir/test/Dialect/OpenMP/invalid.mlir
index 5215e363d43e66..d21c3e513277f3 100644
--- a/mlir/test/Dialect/OpenMP/invalid.mlir
+++ b/mlir/test/Dialect/OpenMP/invalid.mlir
@@ -1894,7 +1894,7 @@ func.func @scan_test_2(%lb: i32, %ub: i32, %step: i32) {
%test1f32 = "test.f32"() : () -> (!llvm.ptr)
omp.taskloop reduction(Id:InScan, @add_f32 %test1f32 -> %arg1 : !llvm.ptr) {
omp.loop_nest (%i, %j) : i32 = (%lb, %ub) to (%ub, %lb) step (%step, %step) {
- // expected-error @below {{Scan Operation should be enclosed within a parent WORSKSHARING LOOP or SIMD with INSCAN reduction modifier}}
+ // expected-error @below {{SCAN directive needs to be enclosed within a parent worksharing loop construct or SIMD construct with INSCAN reduction modifier}}
omp.scan inclusive(%test1f32 : !llvm.ptr)
omp.yield
}
>From eb274aae91fbaf06d0b66cc1eb8a09209a62ca57 Mon Sep 17 00:00:00 2001
From: Anchu Rajendran <asudhaku at amd.com>
Date: Fri, 6 Dec 2024 17:36:20 -0600
Subject: [PATCH 3/5] R3: Addressing a few review comments
---
.../mlir/Dialect/OpenMP/OpenMPClauses.td | 12 +-
.../mlir/Dialect/OpenMP/OpenMPEnums.td | 2 +-
mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td | 20 +--
mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp | 136 ++++++++----------
mlir/test/Dialect/OpenMP/invalid.mlir | 6 +-
mlir/test/Dialect/OpenMP/ops.mlir | 8 +-
6 files changed, 85 insertions(+), 99 deletions(-)
diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
index 56ecc15dfc8799..d71dc70ed2d963 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
@@ -353,7 +353,7 @@ class OpenMP_ExclusiveClauseSkip<
let extraClassDeclaration = [{
bool hasExclusiveVars() {
- return getExclusiveVars().size()>0;
+ return !getExclusiveVars().empty();
}
}];
@@ -363,6 +363,9 @@ class OpenMP_ExclusiveClauseSkip<
is specified, the input phase excludes the preceding structured block
sequence and instead includes the following structured block sequence,
while the scan phase includes the preceding structured block sequence.
+
+ The `exclusive_vars` is a variadic list of operands that specifies the
+ scan-reduction accumulator symbols.
}];
}
@@ -497,7 +500,7 @@ class OpenMP_InclusiveClauseSkip<
let extraClassDeclaration = [{
bool hasInclusiveVars() {
- return getInclusiveVars().size()>0;
+ return !getInclusiveVars().empty();
}
}];
@@ -506,6 +509,9 @@ class OpenMP_InclusiveClauseSkip<
structured block into two structured block sequences. If it is specified,
the input phase includes the preceding structured block sequence and the
scan phase includes the following structured block sequence.
+
+ The `inclusive_vars` is a variadic list of operands that specifies the
+ scan-reduction accumulator symbols.
}];
}
@@ -612,7 +618,6 @@ class OpenMP_InReductionClauseSkip<
];
let arguments = (ins
- OptionalAttr<ReductionModifierAttr>:$in_reduction_mod,
Variadic<OpenMP_PointerLikeType>:$in_reduction_vars,
OptionalAttr<DenseBoolArrayAttr>:$in_reduction_byref,
OptionalAttr<SymbolRefArrayAttr>:$in_reduction_syms
@@ -1300,7 +1305,6 @@ class OpenMP_TaskReductionClauseSkip<
];
let arguments = (ins
- OptionalAttr<ReductionModifierAttr>:$task_reduction_mod,
Variadic<OpenMP_PointerLikeType>:$task_reduction_vars,
OptionalAttr<DenseBoolArrayAttr>:$task_reduction_byref,
OptionalAttr<SymbolRefArrayAttr>:$task_reduction_syms
diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPEnums.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPEnums.td
index 129bd03f809c49..22777b4f0964b3 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPEnums.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPEnums.td
@@ -196,7 +196,7 @@ def ReductionModifier : OpenMP_I32EnumAttr<
]>;
def ReductionModifierAttr : OpenMP_EnumAttr<ReductionModifier,
- "reduction_modifier"> {
+ "reduction_modifier"> {
let assemblyFormat = "`(` $value `)`";
}
diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
index cc375a0c021e81..e653d0cf7f2f99 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
@@ -702,7 +702,7 @@ def TaskOp
let assemblyFormat = clausesAssemblyFormat # [{
custom<InReductionPrivateRegion>(
- $region, $in_reduction_mod, $in_reduction_vars, type($in_reduction_vars),
+ $region, $in_reduction_vars, type($in_reduction_vars),
$in_reduction_byref, $in_reduction_syms, $private_vars,
type($private_vars), $private_syms) attr-dict
}];
@@ -780,7 +780,7 @@ def TaskloopOp : OpenMP_Op<"taskloop", traits = [
let assemblyFormat = clausesAssemblyFormat # [{
custom<InReductionPrivateReductionRegion>(
- $region, $in_reduction_mod, $in_reduction_vars, type($in_reduction_vars),
+ $region, $in_reduction_vars, type($in_reduction_vars),
$in_reduction_byref, $in_reduction_syms, $private_vars,
type($private_vars), $private_syms, $reduction_mod, $reduction_vars,
type($reduction_vars), $reduction_byref, $reduction_syms) attr-dict
@@ -827,7 +827,7 @@ def TaskgroupOp : OpenMP_Op<"taskgroup", traits = [
let assemblyFormat = clausesAssemblyFormat # [{
custom<TaskReductionRegion>(
- $region, $task_reduction_mod, $task_reduction_vars, type($task_reduction_vars),
+ $region, $task_reduction_vars, type($task_reduction_vars),
$task_reduction_byref, $task_reduction_syms) attr-dict
}];
@@ -1289,7 +1289,7 @@ def TargetOp : OpenMP_Op<"target", traits = [
let assemblyFormat = clausesAssemblyFormat # [{
custom<HostEvalInReductionMapPrivateRegion>(
- $region, $host_eval_vars, type($host_eval_vars), $in_reduction_mod, $in_reduction_vars,
+ $region, $host_eval_vars, type($host_eval_vars), $in_reduction_vars,
type($in_reduction_vars), $in_reduction_byref, $in_reduction_syms,
$map_vars, type($map_vars), $private_vars, type($private_vars),
$private_syms, $private_maps) attr-dict
@@ -1707,15 +1707,15 @@ def CancellationPointOp : OpenMP_Op<"cancellation_point", clauses = [
}
def ScanOp : OpenMP_Op<"scan", [
- AttrSizedOperandSegments, RecipeInterface, IsolatedFromAbove
+ AttrSizedOperandSegments
], clauses = [
- OpenMP_InclusiveClause, OpenMP_ExclusiveClause]> {
+ OpenMP_InclusiveClause, OpenMP_ExclusiveClause]> {
let summary = "scan directive";
let description = [{
- The scan directive allows to specify scan reduction. Scan directive
- should be enclosed with in a parent directive along with which, a
- reduction clause with `InScan` modifier must be specified. Scan directive
- allows to separate code blocks to input phase and scan phase in the region
+ The scan directive allows to specify scan reductions. It should be
+ enclosed within a parent directive along with which a reduction clause
+ with `inscan` modifier must be specified. The scan directive allows to
+ split code blocks into input phase and scan phase in the region
enclosed by the parent.
}] # clausesDescription;
diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
index 08df20fbb349cf..65b6496ea5b31d 100644
--- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
+++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
@@ -496,20 +496,17 @@ struct PrivateParseArgs {
};
struct ReductionParseArgs {
- ReductionModifierAttr &reductionMod;
SmallVectorImpl<OpAsmParser::UnresolvedOperand> &vars;
SmallVectorImpl<Type> &types;
DenseBoolArrayAttr &byref;
ArrayAttr &syms;
- ReductionParseArgs(ReductionModifierAttr &redMod,
- SmallVectorImpl<OpAsmParser::UnresolvedOperand> &vars,
+ ReductionModifierAttr *modifier;
+ ReductionParseArgs(SmallVectorImpl<OpAsmParser::UnresolvedOperand> &vars,
SmallVectorImpl<Type> &types, DenseBoolArrayAttr &byref,
- ArrayAttr &syms)
- : reductionMod(redMod), vars(vars), types(types), byref(byref),
- syms(syms) {}
+ ArrayAttr &syms, ReductionModifierAttr *mod = nullptr)
+ : vars(vars), types(types), byref(byref), syms(syms), modifier(mod) {}
};
-// specifies the arguments needs for `reduction` clause
struct AllRegionParseArgs {
std::optional<MapParseArgs> hostEvalArgs;
std::optional<ReductionParseArgs> inReductionArgs;
@@ -529,7 +526,7 @@ static ParseResult parseClauseWithRegionArgs(
SmallVectorImpl<OpAsmParser::Argument> &regionPrivateArgs,
ArrayAttr *symbols = nullptr, DenseI64ArrayAttr *mapIndices = nullptr,
DenseBoolArrayAttr *byref = nullptr,
- ReductionModifierAttr *reductionMod = nullptr) {
+ ReductionModifierAttr *modifier = nullptr) {
SmallVector<SymbolRefAttr> symbolVec;
SmallVector<int64_t> mapIndicesVec;
SmallVector<bool> isByRefVec;
@@ -539,13 +536,17 @@ static ParseResult parseClauseWithRegionArgs(
return failure();
StringRef enumStr;
- if (succeeded(parser.parseOptionalKeyword("Id"))) {
+ if (succeeded(parser.parseOptionalKeyword("mod"))) {
if (parser.parseColon() || parser.parseKeyword(&enumStr) ||
parser.parseComma())
return failure();
std::optional<ReductionModifier> enumValue =
symbolizeReductionModifier(enumStr);
- *reductionMod = ReductionModifierAttr::get(parser.getContext(), *enumValue);
+ if (!enumValue.has_value())
+ return failure();
+ *modifier = ReductionModifierAttr::get(parser.getContext(), *enumValue);
+ if (!*modifier)
+ return failure();
}
if (parser.parseCommaSeparatedList([&]() {
@@ -655,7 +656,7 @@ static ParseResult parseBlockArgClause(
if (failed(parseClauseWithRegionArgs(
parser, reductionArgs->vars, reductionArgs->types, entryBlockArgs,
&reductionArgs->syms, /*mapIndices=*/nullptr, &reductionArgs->byref,
- &(reductionArgs->reductionMod))))
+ reductionArgs->modifier)))
return failure();
}
return success();
@@ -710,7 +711,7 @@ static ParseResult parseBlockArgRegion(OpAsmParser &parser, Region &region,
static ParseResult parseHostEvalInReductionMapPrivateRegion(
OpAsmParser &parser, Region &region,
SmallVectorImpl<OpAsmParser::UnresolvedOperand> &hostEvalVars,
- SmallVectorImpl<Type> &hostEvalTypes, ReductionModifierAttr &inReductionMod,
+ SmallVectorImpl<Type> &hostEvalTypes,
SmallVectorImpl<OpAsmParser::UnresolvedOperand> &inReductionVars,
SmallVectorImpl<Type> &inReductionTypes,
DenseBoolArrayAttr &inReductionByref, ArrayAttr &inReductionSyms,
@@ -721,9 +722,8 @@ static ParseResult parseHostEvalInReductionMapPrivateRegion(
DenseI64ArrayAttr &privateMaps) {
AllRegionParseArgs args;
args.hostEvalArgs.emplace(hostEvalVars, hostEvalTypes);
- args.inReductionArgs.emplace(inReductionMod, inReductionVars,
- inReductionTypes, inReductionByref,
- inReductionSyms);
+ args.inReductionArgs.emplace(inReductionVars, inReductionTypes,
+ inReductionByref, inReductionSyms);
args.mapArgs.emplace(mapVars, mapTypes);
args.privateArgs.emplace(privateVars, privateTypes, privateSyms,
&privateMaps);
@@ -731,22 +731,21 @@ static ParseResult parseHostEvalInReductionMapPrivateRegion(
}
static ParseResult parseInReductionPrivateRegion(
- OpAsmParser &parser, Region &region, ReductionModifierAttr &inReductionMod,
+ OpAsmParser &parser, Region &region,
SmallVectorImpl<OpAsmParser::UnresolvedOperand> &inReductionVars,
SmallVectorImpl<Type> &inReductionTypes,
DenseBoolArrayAttr &inReductionByref, ArrayAttr &inReductionSyms,
llvm::SmallVectorImpl<OpAsmParser::UnresolvedOperand> &privateVars,
llvm::SmallVectorImpl<Type> &privateTypes, ArrayAttr &privateSyms) {
AllRegionParseArgs args;
- args.inReductionArgs.emplace(inReductionMod, inReductionVars,
- inReductionTypes, inReductionByref,
- inReductionSyms);
+ args.inReductionArgs.emplace(inReductionVars, inReductionTypes,
+ inReductionByref, inReductionSyms);
args.privateArgs.emplace(privateVars, privateTypes, privateSyms);
return parseBlockArgRegion(parser, region, args);
}
static ParseResult parseInReductionPrivateReductionRegion(
- OpAsmParser &parser, Region &region, ReductionModifierAttr &inReductionMod,
+ OpAsmParser &parser, Region &region,
SmallVectorImpl<OpAsmParser::UnresolvedOperand> &inReductionVars,
SmallVectorImpl<Type> &inReductionTypes,
DenseBoolArrayAttr &inReductionByref, ArrayAttr &inReductionSyms,
@@ -757,12 +756,11 @@ static ParseResult parseInReductionPrivateReductionRegion(
SmallVectorImpl<Type> &reductionTypes, DenseBoolArrayAttr &reductionByref,
ArrayAttr &reductionSyms) {
AllRegionParseArgs args;
- args.inReductionArgs.emplace(inReductionMod, inReductionVars,
- inReductionTypes, inReductionByref,
- inReductionSyms);
+ args.inReductionArgs.emplace(inReductionVars, inReductionTypes,
+ inReductionByref, inReductionSyms);
args.privateArgs.emplace(privateVars, privateTypes, privateSyms);
- args.reductionArgs.emplace(reductionMod, reductionVars, reductionTypes,
- reductionByref, reductionSyms);
+ args.reductionArgs.emplace(reductionVars, reductionTypes, reductionByref,
+ reductionSyms, &reductionMod);
return parseBlockArgRegion(parser, region, args);
}
@@ -785,21 +783,19 @@ static ParseResult parsePrivateReductionRegion(
ArrayAttr &reductionSyms) {
AllRegionParseArgs args;
args.privateArgs.emplace(privateVars, privateTypes, privateSyms);
- args.reductionArgs.emplace(reductionMod, reductionVars, reductionTypes,
- reductionByref, reductionSyms);
+ args.reductionArgs.emplace(reductionVars, reductionTypes, reductionByref,
+ reductionSyms, &reductionMod);
return parseBlockArgRegion(parser, region, args);
}
static ParseResult parseTaskReductionRegion(
OpAsmParser &parser, Region &region,
- ReductionModifierAttr &taskReductionMod,
SmallVectorImpl<OpAsmParser::UnresolvedOperand> &taskReductionVars,
SmallVectorImpl<Type> &taskReductionTypes,
DenseBoolArrayAttr &taskReductionByref, ArrayAttr &taskReductionSyms) {
AllRegionParseArgs args;
- args.taskReductionArgs.emplace(taskReductionMod, taskReductionVars,
- taskReductionTypes, taskReductionByref,
- taskReductionSyms);
+ args.taskReductionArgs.emplace(taskReductionVars, taskReductionTypes,
+ taskReductionByref, taskReductionSyms);
return parseBlockArgRegion(parser, region, args);
}
@@ -835,15 +831,14 @@ struct PrivatePrintArgs {
: vars(vars), types(types), syms(syms), mapIndices(mapIndices) {}
};
struct ReductionPrintArgs {
- ReductionModifierAttr reductionMod;
ValueRange vars;
TypeRange types;
DenseBoolArrayAttr byref;
ArrayAttr syms;
- ReductionPrintArgs(ReductionModifierAttr reductionMod, ValueRange vars,
- TypeRange types, DenseBoolArrayAttr byref, ArrayAttr syms)
- : reductionMod(reductionMod), vars(vars), types(types), byref(byref),
- syms(syms) {}
+ ReductionModifierAttr *modifier;
+ ReductionPrintArgs(ValueRange vars, TypeRange types, DenseBoolArrayAttr byref,
+ ArrayAttr syms, ReductionModifierAttr *mod = nullptr)
+ : vars(vars), types(types), byref(byref), syms(syms), modifier(mod) {}
};
struct AllRegionPrintArgs {
std::optional<MapPrintArgs> hostEvalArgs;
@@ -862,15 +857,14 @@ static void printClauseWithRegionArgs(
ValueRange argsSubrange, ValueRange operands, TypeRange types,
ArrayAttr symbols = nullptr, DenseI64ArrayAttr mapIndices = nullptr,
DenseBoolArrayAttr byref = nullptr,
- ReductionModifierAttr reductionMod = nullptr) {
+ ReductionModifierAttr *modifier = nullptr) {
if (argsSubrange.empty())
return;
p << clauseName << "(";
- if (reductionMod) {
- p << "Id: " << stringifyReductionModifier(reductionMod.getValue()) << ", ";
- }
+ if (modifier && *modifier)
+ p << "mod: " << stringifyReductionModifier(modifier->getValue()) << ", ";
if (!symbols) {
llvm::SmallVector<Attribute> values(operands.size(), nullptr);
@@ -929,10 +923,10 @@ printBlockArgClause(OpAsmPrinter &p, MLIRContext *ctx, StringRef clauseName,
ValueRange argsSubrange,
std::optional<ReductionPrintArgs> reductionArgs) {
if (reductionArgs)
- printClauseWithRegionArgs(
- p, ctx, clauseName, argsSubrange, reductionArgs->vars,
- reductionArgs->types, reductionArgs->syms, /*mapIndices=*/nullptr,
- reductionArgs->byref, reductionArgs->reductionMod);
+ printClauseWithRegionArgs(p, ctx, clauseName, argsSubrange,
+ reductionArgs->vars, reductionArgs->types,
+ reductionArgs->syms, /*mapIndices=*/nullptr,
+ reductionArgs->byref, reductionArgs->modifier);
}
static void printBlockArgRegion(OpAsmPrinter &p, Operation *op, Region &region,
@@ -964,53 +958,47 @@ static void printBlockArgRegion(OpAsmPrinter &p, Operation *op, Region &region,
static void printHostEvalInReductionMapPrivateRegion(
OpAsmPrinter &p, Operation *op, Region &region, ValueRange hostEvalVars,
- TypeRange hostEvalTypes, ReductionModifierAttr inReductionMod,
- ValueRange inReductionVars, TypeRange inReductionTypes,
- DenseBoolArrayAttr inReductionByref, ArrayAttr inReductionSyms,
- ValueRange mapVars, TypeRange mapTypes, ValueRange privateVars,
- TypeRange privateTypes, ArrayAttr privateSyms,
+ TypeRange hostEvalTypes, ValueRange inReductionVars,
+ TypeRange inReductionTypes, DenseBoolArrayAttr inReductionByref,
+ ArrayAttr inReductionSyms, ValueRange mapVars, TypeRange mapTypes,
+ ValueRange privateVars, TypeRange privateTypes, ArrayAttr privateSyms,
DenseI64ArrayAttr privateMaps) {
AllRegionPrintArgs args;
args.hostEvalArgs.emplace(hostEvalVars, hostEvalTypes);
- args.inReductionArgs.emplace(inReductionMod, inReductionVars,
- inReductionTypes, inReductionByref,
- inReductionSyms);
+ args.inReductionArgs.emplace(inReductionVars, inReductionTypes,
+ inReductionByref, inReductionSyms);
args.mapArgs.emplace(mapVars, mapTypes);
args.privateArgs.emplace(privateVars, privateTypes, privateSyms, privateMaps);
printBlockArgRegion(p, op, region, args);
}
static void printInReductionPrivateRegion(
- OpAsmPrinter &p, Operation *op, Region &region,
- ReductionModifierAttr inReductionMod, ValueRange inReductionVars,
+ OpAsmPrinter &p, Operation *op, Region &region, ValueRange inReductionVars,
TypeRange inReductionTypes, DenseBoolArrayAttr inReductionByref,
ArrayAttr inReductionSyms, ValueRange privateVars, TypeRange privateTypes,
ArrayAttr privateSyms) {
AllRegionPrintArgs args;
- args.inReductionArgs.emplace(inReductionMod, inReductionVars,
- inReductionTypes, inReductionByref,
- inReductionSyms);
+ args.inReductionArgs.emplace(inReductionVars, inReductionTypes,
+ inReductionByref, inReductionSyms);
args.privateArgs.emplace(privateVars, privateTypes, privateSyms,
/*mapIndices=*/nullptr);
printBlockArgRegion(p, op, region, args);
}
static void printInReductionPrivateReductionRegion(
- OpAsmPrinter &p, Operation *op, Region &region,
- ReductionModifierAttr inReductionMod, ValueRange inReductionVars,
+ OpAsmPrinter &p, Operation *op, Region &region, ValueRange inReductionVars,
TypeRange inReductionTypes, DenseBoolArrayAttr inReductionByref,
ArrayAttr inReductionSyms, ValueRange privateVars, TypeRange privateTypes,
ArrayAttr privateSyms, ReductionModifierAttr reductionMod,
ValueRange reductionVars, TypeRange reductionTypes,
DenseBoolArrayAttr reductionByref, ArrayAttr reductionSyms) {
AllRegionPrintArgs args;
- args.inReductionArgs.emplace(inReductionMod, inReductionVars,
- inReductionTypes, inReductionByref,
- inReductionSyms);
+ args.inReductionArgs.emplace(inReductionVars, inReductionTypes,
+ inReductionByref, inReductionSyms);
args.privateArgs.emplace(privateVars, privateTypes, privateSyms,
/*mapIndices=*/nullptr);
- args.reductionArgs.emplace(reductionMod, reductionVars, reductionTypes,
- reductionByref, reductionSyms);
+ args.reductionArgs.emplace(reductionVars, reductionTypes, reductionByref,
+ reductionSyms, &reductionMod);
printBlockArgRegion(p, op, region, args);
}
@@ -1032,22 +1020,20 @@ static void printPrivateReductionRegion(
AllRegionPrintArgs args;
args.privateArgs.emplace(privateVars, privateTypes, privateSyms,
/*mapIndices=*/nullptr);
- args.reductionArgs.emplace(reductionMod, reductionVars, reductionTypes,
- reductionByref, reductionSyms);
+ args.reductionArgs.emplace(reductionVars, reductionTypes, reductionByref,
+ reductionSyms, &reductionMod);
printBlockArgRegion(p, op, region, args);
}
static void printTaskReductionRegion(OpAsmPrinter &p, Operation *op,
Region &region,
- ReductionModifierAttr taskReductionMod,
ValueRange taskReductionVars,
TypeRange taskReductionTypes,
DenseBoolArrayAttr taskReductionByref,
ArrayAttr taskReductionSyms) {
AllRegionPrintArgs args;
- args.taskReductionArgs.emplace(taskReductionMod, taskReductionVars,
- taskReductionTypes, taskReductionByref,
- taskReductionSyms);
+ args.taskReductionArgs.emplace(taskReductionVars, taskReductionTypes,
+ taskReductionByref, taskReductionSyms);
printBlockArgRegion(p, op, region, args);
}
@@ -1764,7 +1750,6 @@ void TargetOp::build(OpBuilder &builder, OperationState &state,
clauses.bare, makeArrayAttr(ctx, clauses.dependKinds),
clauses.dependVars, clauses.device, clauses.hasDeviceAddrVars,
clauses.hostEvalVars, clauses.ifExpr,
- /*in_reduction_mod=*/nullptr,
/*in_reduction_vars=*/{}, /*in_reduction_byref=*/nullptr,
/*in_reduction_syms=*/nullptr, clauses.isDevicePtrVars,
clauses.mapVars, clauses.nowait, clauses.privateVars,
@@ -2540,8 +2525,7 @@ void TaskOp::build(OpBuilder &builder, OperationState &state,
MLIRContext *ctx = builder.getContext();
TaskOp::build(builder, state, clauses.allocateVars, clauses.allocatorVars,
makeArrayAttr(ctx, clauses.dependKinds), clauses.dependVars,
- clauses.final, clauses.ifExpr, clauses.inReductionMod,
- clauses.inReductionVars,
+ clauses.final, clauses.ifExpr, clauses.inReductionVars,
makeDenseBoolArrayAttr(ctx, clauses.inReductionByref),
makeArrayAttr(ctx, clauses.inReductionSyms), clauses.mergeable,
clauses.priority, /*private_vars=*/clauses.privateVars,
@@ -2567,8 +2551,7 @@ void TaskgroupOp::build(OpBuilder &builder, OperationState &state,
const TaskgroupOperands &clauses) {
MLIRContext *ctx = builder.getContext();
TaskgroupOp::build(builder, state, clauses.allocateVars,
- clauses.allocatorVars, clauses.taskReductionMod,
- clauses.taskReductionVars,
+ clauses.allocatorVars, clauses.taskReductionVars,
makeDenseBoolArrayAttr(ctx, clauses.taskReductionByref),
makeArrayAttr(ctx, clauses.taskReductionSyms));
}
@@ -2589,8 +2572,7 @@ void TaskloopOp::build(OpBuilder &builder, OperationState &state,
// TODO Store clauses in op: privateVars, privateSyms.
TaskloopOp::build(
builder, state, clauses.allocateVars, clauses.allocatorVars,
- clauses.final, clauses.grainsize, clauses.ifExpr, clauses.inReductionMod,
- clauses.inReductionVars,
+ clauses.final, clauses.grainsize, clauses.ifExpr, clauses.inReductionVars,
makeDenseBoolArrayAttr(ctx, clauses.inReductionByref),
makeArrayAttr(ctx, clauses.inReductionSyms), clauses.mergeable,
clauses.nogroup, clauses.numTasks, clauses.priority, /*private_vars=*/{},
diff --git a/mlir/test/Dialect/OpenMP/invalid.mlir b/mlir/test/Dialect/OpenMP/invalid.mlir
index d21c3e513277f3..8d095287db064e 100644
--- a/mlir/test/Dialect/OpenMP/invalid.mlir
+++ b/mlir/test/Dialect/OpenMP/invalid.mlir
@@ -1840,7 +1840,7 @@ combiner {
func.func @scan_test_1(%lb: i32, %ub: i32, %step: i32) {
%test1f32 = "test.f32"() : () -> (!llvm.ptr)
%test2f32 = "test.f32"() : () -> (!llvm.ptr)
- omp.wsloop reduction(Id:InScan, @add_f32 %test1f32 -> %arg1 : !llvm.ptr) {
+ omp.wsloop reduction(mod:InScan, @add_f32 %test1f32 -> %arg1 : !llvm.ptr) {
omp.loop_nest (%i, %j) : i32 = (%lb, %ub) to (%ub, %lb) step (%step, %step) {
// expected-error @below {{List item should appear in REDUCTION clause of the parent}}
omp.scan inclusive(%test2f32 : !llvm.ptr)
@@ -1866,7 +1866,7 @@ combiner {
func.func @scan_test_2(%lb: i32, %ub: i32, %step: i32) {
%test1f32 = "test.f32"() : () -> (!llvm.ptr)
- omp.wsloop reduction(Id:InScan, @add_f32 %test1f32 -> %arg1 : !llvm.ptr) {
+ omp.wsloop reduction(mod:InScan, @add_f32 %test1f32 -> %arg1 : !llvm.ptr) {
omp.loop_nest (%i, %j) : i32 = (%lb, %ub) to (%ub, %lb) step (%step, %step) {
// expected-error @below {{Exactly one of EXCLUSIVE or INCLUSIVE clause is expected}}
omp.scan inclusive(%test1f32 : !llvm.ptr) exclusive(%test1f32: !llvm.ptr)
@@ -1892,7 +1892,7 @@ combiner {
func.func @scan_test_2(%lb: i32, %ub: i32, %step: i32) {
%test1f32 = "test.f32"() : () -> (!llvm.ptr)
- omp.taskloop reduction(Id:InScan, @add_f32 %test1f32 -> %arg1 : !llvm.ptr) {
+ omp.taskloop reduction(mod:InScan, @add_f32 %test1f32 -> %arg1 : !llvm.ptr) {
omp.loop_nest (%i, %j) : i32 = (%lb, %ub) to (%ub, %lb) step (%step, %step) {
// expected-error @below {{SCAN directive needs to be enclosed within a parent worksharing loop construct or SIMD construct with INSCAN reduction modifier}}
omp.scan inclusive(%test1f32 : !llvm.ptr)
diff --git a/mlir/test/Dialect/OpenMP/ops.mlir b/mlir/test/Dialect/OpenMP/ops.mlir
index 25610d030d3f1c..2c8e9ad3fe0700 100644
--- a/mlir/test/Dialect/OpenMP/ops.mlir
+++ b/mlir/test/Dialect/OpenMP/ops.mlir
@@ -904,16 +904,16 @@ func.func @wsloop_reduction(%lb : index, %ub : index, %step : index) {
func.func @wsloop_inscan_reduction(%lb : index, %ub : index, %step : index) {
%c1 = arith.constant 1 : i32
%0 = llvm.alloca %c1 x i32 : (i32) -> !llvm.ptr
- // CHECK: reduction(Id: InScan, @add_f32 %{{.+}} -> %[[PRV:.+]] : !llvm.ptr)
- omp.wsloop reduction(Id:InScan, @add_f32 %0 -> %prv : !llvm.ptr) {
+ // CHECK: reduction(mod: InScan, @add_f32 %{{.+}} -> %[[PRV:.+]] : !llvm.ptr)
+ omp.wsloop reduction(mod:InScan, @add_f32 %0 -> %prv : !llvm.ptr) {
omp.loop_nest (%iv) : index = (%lb) to (%ub) step (%step) {
// CHECK: omp.scan inclusive(%{{.*}} : !llvm.ptr)
omp.scan inclusive(%0 : !llvm.ptr)
omp.yield
}
}
- // CHECK: reduction(Id: InScan, @add_f32 %{{.+}} -> %[[PRV:.+]] : !llvm.ptr)
- omp.wsloop reduction(Id:InScan, @add_f32 %0 -> %prv : !llvm.ptr) {
+ // CHECK: reduction(mod: InScan, @add_f32 %{{.+}} -> %[[PRV:.+]] : !llvm.ptr)
+ omp.wsloop reduction(mod:InScan, @add_f32 %0 -> %prv : !llvm.ptr) {
omp.loop_nest (%iv) : index = (%lb) to (%ub) step (%step) {
// CHECK: omp.scan exclusive(%{{.*}} : !llvm.ptr)
omp.scan exclusive(%0 : !llvm.ptr)
>From 5c984a048eefe479f664d287c32ea55e0d12128f Mon Sep 17 00:00:00 2001
From: Anchu Rajendran <asudhaku at amd.com>
Date: Fri, 10 Jan 2025 14:15:49 -0600
Subject: [PATCH 4/5] R4: Adding memory write side effects to ScanOp and
addressing other review comments
---
.../mlir/Dialect/OpenMP/OpenMPClauses.td | 9 +++---
.../mlir/Dialect/OpenMP/OpenMPEnums.td | 6 ++--
mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td | 2 +-
mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp | 29 ++++++++++-------
.../OpenMP/OpenMPToLLVMIRTranslation.cpp | 4 +++
mlir/test/Dialect/OpenMP/invalid.mlir | 6 ++--
mlir/test/Dialect/OpenMP/ops.mlir | 12 +++----
mlir/test/Target/LLVMIR/openmp-todo.mlir | 32 +++++++++++++++++++
8 files changed, 71 insertions(+), 29 deletions(-)
diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
index d71dc70ed2d963..a8d97a36df79ee 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
@@ -1188,10 +1188,11 @@ class OpenMP_ReductionClauseSkip<
// Description varies depending on the operation.
let description = [{
- Reductions can be performed by specifying reduction accumulator variables in
- `reduction_vars`, symbols referring to reduction declarations in the
- `reduction_syms` attribute, and whether the reduction variable should be
- passed into the reduction region by value or by reference in
+ Reductions can be performed by specifying the reduction modifier
+ (`default`, `inscan` or `task`) in `reduction_mod`, reduction accumulator
+ variables in `reduction_vars`, symbols referring to reduction declarations
+ in the `reduction_syms` attribute, and whether the reduction variable
+ should be passed into the reduction region by value or by reference in
`reduction_byref`. Each reduction is identified by the accumulator it uses
and accumulators must not be repeated in the same reduction. A private
variable corresponding to the accumulator is used in place of the
diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPEnums.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPEnums.td
index 22777b4f0964b3..4fa7630166f5f8 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPEnums.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPEnums.td
@@ -183,9 +183,9 @@ def OrderModifierAttr : EnumAttr<OpenMP_Dialect, OrderModifier,
// reduction_modifier enum.
//===----------------------------------------------------------------------===//
-def ReductionModifierDefault : I32EnumAttrCase<"Default", 0>;
-def ReductionModifierInScan : I32EnumAttrCase<"InScan", 1>;
-def ReductionModifierTask : I32EnumAttrCase<"Task", 2>;
+def ReductionModifierDefault : I32EnumAttrCase<"defaultmod", 0>;
+def ReductionModifierInScan : I32EnumAttrCase<"inscan", 1>;
+def ReductionModifierTask : I32EnumAttrCase<"task", 2>;
def ReductionModifier : OpenMP_I32EnumAttr<
"ReductionModifier",
diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
index e653d0cf7f2f99..580c9c6ef6fde8 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
@@ -1707,7 +1707,7 @@ def CancellationPointOp : OpenMP_Op<"cancellation_point", clauses = [
}
def ScanOp : OpenMP_Op<"scan", [
- AttrSizedOperandSegments
+ AttrSizedOperandSegments, MemoryEffects<[MemWrite]>
], clauses = [
OpenMP_InclusiveClause, OpenMP_ExclusiveClause]> {
let summary = "scan directive";
diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
index 65b6496ea5b31d..a8bf501ef28e6e 100644
--- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
+++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
@@ -535,8 +535,8 @@ static ParseResult parseClauseWithRegionArgs(
if (parser.parseLParen())
return failure();
- StringRef enumStr;
- if (succeeded(parser.parseOptionalKeyword("mod"))) {
+ if (modifier && succeeded(parser.parseOptionalKeyword("mod"))) {
+ StringRef enumStr;
if (parser.parseColon() || parser.parseKeyword(&enumStr) ||
parser.parseComma())
return failure();
@@ -3169,28 +3169,33 @@ LogicalResult ScanOp::verify() {
}
const OperandRange &scanVars =
hasExclusiveVars() ? getExclusiveVars() : getInclusiveVars();
- auto verifyScanVarsInReduction = [&scanVars](OperandRange reductionVars) {
+ auto verifyScanVarsInReduction = [&scanVars](ValueRange reductionVars) {
for (const auto &it : scanVars)
if (!llvm::is_contained(reductionVars, it))
return false;
return true;
};
- if (mlir::omp::WsloopOp parentOp =
+ if (mlir::omp::WsloopOp parentWsLoopOp =
(*this)->getParentOfType<mlir::omp::WsloopOp>()) {
- if (parentOp.getReductionModAttr() &&
- parentOp.getReductionModAttr().getValue() ==
- mlir::omp::ReductionModifier::InScan) {
- if (!verifyScanVarsInReduction(parentOp.getReductionVars())) {
+ if (parentWsLoopOp.getReductionModAttr() &&
+ parentWsLoopOp.getReductionModAttr().getValue() ==
+ mlir::omp::ReductionModifier::inscan) {
+ auto iface = llvm::cast<mlir::omp::BlockArgOpenMPOpInterface>(
+ parentWsLoopOp.getOperation());
+ if (!verifyScanVarsInReduction(iface.getReductionBlockArgs())) {
return emitError(
"List item should appear in REDUCTION clause of the parent");
}
return success();
}
- } else if (mlir::omp::SimdOp parentOp =
+ } else if (mlir::omp::SimdOp parentSimdOp =
(*this)->getParentOfType<mlir::omp::SimdOp>()) {
- if (parentOp.getReductionModAttr().getValue() ==
- mlir::omp::ReductionModifier::InScan) {
- if (!verifyScanVarsInReduction(parentOp.getReductionVars())) {
+ if (parentSimdOp.getReductionModAttr() &&
+ parentSimdOp.getReductionModAttr().getValue() ==
+ mlir::omp::ReductionModifier::inscan) {
+ auto iface = llvm::cast<mlir::omp::BlockArgOpenMPOpInterface>(
+ parentSimdOp.getOperation());
+ if (!verifyScanVarsInReduction(iface.getReductionBlockArgs())) {
return emitError(
"List item should appear in REDUCTION clause of the parent");
}
diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
index 0be515e63b470c..3c456516eeea51 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -4591,6 +4591,10 @@ convertHostOrTargetOperation(Operation *op, llvm::IRBuilderBase &builder,
.Case([&](omp::AtomicCaptureOp op) {
return convertOmpAtomicCapture(op, builder, moduleTranslation);
})
+ .Case([&](omp::ScanOp) {
+ return op->emitError()
+ << "not yet implemented: " << op->getName() << " operation";
+ })
.Case([&](omp::SectionsOp) {
return convertOmpSections(*op, builder, moduleTranslation);
})
diff --git a/mlir/test/Dialect/OpenMP/invalid.mlir b/mlir/test/Dialect/OpenMP/invalid.mlir
index 8d095287db064e..e3e6397d84d77c 100644
--- a/mlir/test/Dialect/OpenMP/invalid.mlir
+++ b/mlir/test/Dialect/OpenMP/invalid.mlir
@@ -1840,7 +1840,7 @@ combiner {
func.func @scan_test_1(%lb: i32, %ub: i32, %step: i32) {
%test1f32 = "test.f32"() : () -> (!llvm.ptr)
%test2f32 = "test.f32"() : () -> (!llvm.ptr)
- omp.wsloop reduction(mod:InScan, @add_f32 %test1f32 -> %arg1 : !llvm.ptr) {
+ omp.wsloop reduction(mod:inscan, @add_f32 %test1f32 -> %arg1 : !llvm.ptr) {
omp.loop_nest (%i, %j) : i32 = (%lb, %ub) to (%ub, %lb) step (%step, %step) {
// expected-error @below {{List item should appear in REDUCTION clause of the parent}}
omp.scan inclusive(%test2f32 : !llvm.ptr)
@@ -1866,7 +1866,7 @@ combiner {
func.func @scan_test_2(%lb: i32, %ub: i32, %step: i32) {
%test1f32 = "test.f32"() : () -> (!llvm.ptr)
- omp.wsloop reduction(mod:InScan, @add_f32 %test1f32 -> %arg1 : !llvm.ptr) {
+ omp.wsloop reduction(mod:inscan, @add_f32 %test1f32 -> %arg1 : !llvm.ptr) {
omp.loop_nest (%i, %j) : i32 = (%lb, %ub) to (%ub, %lb) step (%step, %step) {
// expected-error @below {{Exactly one of EXCLUSIVE or INCLUSIVE clause is expected}}
omp.scan inclusive(%test1f32 : !llvm.ptr) exclusive(%test1f32: !llvm.ptr)
@@ -1892,7 +1892,7 @@ combiner {
func.func @scan_test_2(%lb: i32, %ub: i32, %step: i32) {
%test1f32 = "test.f32"() : () -> (!llvm.ptr)
- omp.taskloop reduction(mod:InScan, @add_f32 %test1f32 -> %arg1 : !llvm.ptr) {
+ omp.taskloop reduction(mod:inscan, @add_f32 %test1f32 -> %arg1 : !llvm.ptr) {
omp.loop_nest (%i, %j) : i32 = (%lb, %ub) to (%ub, %lb) step (%step, %step) {
// expected-error @below {{SCAN directive needs to be enclosed within a parent worksharing loop construct or SIMD construct with INSCAN reduction modifier}}
omp.scan inclusive(%test1f32 : !llvm.ptr)
diff --git a/mlir/test/Dialect/OpenMP/ops.mlir b/mlir/test/Dialect/OpenMP/ops.mlir
index 2c8e9ad3fe0700..c1259fabe82fba 100644
--- a/mlir/test/Dialect/OpenMP/ops.mlir
+++ b/mlir/test/Dialect/OpenMP/ops.mlir
@@ -904,19 +904,19 @@ func.func @wsloop_reduction(%lb : index, %ub : index, %step : index) {
func.func @wsloop_inscan_reduction(%lb : index, %ub : index, %step : index) {
%c1 = arith.constant 1 : i32
%0 = llvm.alloca %c1 x i32 : (i32) -> !llvm.ptr
- // CHECK: reduction(mod: InScan, @add_f32 %{{.+}} -> %[[PRV:.+]] : !llvm.ptr)
- omp.wsloop reduction(mod:InScan, @add_f32 %0 -> %prv : !llvm.ptr) {
+ // CHECK: reduction(mod: inscan, @add_f32 %{{.+}} -> %[[PRV:.+]] : !llvm.ptr)
+ omp.wsloop reduction(mod:inscan, @add_f32 %0 -> %prv : !llvm.ptr) {
omp.loop_nest (%iv) : index = (%lb) to (%ub) step (%step) {
// CHECK: omp.scan inclusive(%{{.*}} : !llvm.ptr)
- omp.scan inclusive(%0 : !llvm.ptr)
+ omp.scan inclusive(%prv : !llvm.ptr)
omp.yield
}
}
- // CHECK: reduction(mod: InScan, @add_f32 %{{.+}} -> %[[PRV:.+]] : !llvm.ptr)
- omp.wsloop reduction(mod:InScan, @add_f32 %0 -> %prv : !llvm.ptr) {
+ // CHECK: reduction(mod: inscan, @add_f32 %{{.+}} -> %[[PRV:.+]] : !llvm.ptr)
+ omp.wsloop reduction(mod:inscan, @add_f32 %0 -> %prv : !llvm.ptr) {
omp.loop_nest (%iv) : index = (%lb) to (%ub) step (%step) {
// CHECK: omp.scan exclusive(%{{.*}} : !llvm.ptr)
- omp.scan exclusive(%0 : !llvm.ptr)
+ omp.scan exclusive(%prv : !llvm.ptr)
omp.yield
}
}
diff --git a/mlir/test/Target/LLVMIR/openmp-todo.mlir b/mlir/test/Target/LLVMIR/openmp-todo.mlir
index 392a6558dcfa69..d5c43322c03469 100644
--- a/mlir/test/Target/LLVMIR/openmp-todo.mlir
+++ b/mlir/test/Target/LLVMIR/openmp-todo.mlir
@@ -186,6 +186,38 @@ llvm.func @simd_reduction(%lb : i32, %ub : i32, %step : i32, %x : !llvm.ptr) {
// -----
+omp.declare_reduction @add_f32 : f32
+init {
+^bb0(%arg: f32):
+ %0 = llvm.mlir.constant(0.0 : f32) : f32
+ omp.yield (%0 : f32)
+}
+combiner {
+^bb1(%arg0: f32, %arg1: f32):
+ %1 = llvm.fadd %arg0, %arg1 : f32
+ omp.yield (%1 : f32)
+}
+atomic {
+^bb2(%arg2: !llvm.ptr, %arg3: !llvm.ptr):
+ %2 = llvm.load %arg3 : !llvm.ptr -> f32
+ llvm.atomicrmw fadd %arg2, %2 monotonic : !llvm.ptr, f32
+ omp.yield
+}
+llvm.func @scan_reduction(%lb : i32, %ub : i32, %step : i32, %x : !llvm.ptr) {
+ // expected-error at below {{LLVM Translation failed for operation: omp.wsloop}}
+ omp.wsloop reduction(mod:inscan, @add_f32 %x -> %prv : !llvm.ptr) {
+ omp.loop_nest (%iv) : i32 = (%lb) to (%ub) step (%step) {
+ // expected-error at below {{not yet implemented: omp.scan operation}}
+ // expected-error at below {{LLVM Translation failed for operation: omp.scan}}
+ omp.scan inclusive(%prv : !llvm.ptr)
+ omp.yield
+ }
+ }
+ llvm.return
+}
+
+// -----
+
llvm.func @single_allocate(%x : !llvm.ptr) {
// expected-error at below {{not yet implemented: Unhandled clause allocate in omp.single operation}}
// expected-error at below {{LLVM Translation failed for operation: omp.single}}
>From 2adadc13e18269e08a4a2c93b77d8f2c62cda1ce Mon Sep 17 00:00:00 2001
From: Anchu Rajendran <asudhaku at amd.com>
Date: Wed, 15 Jan 2025 13:17:44 -0600
Subject: [PATCH 5/5] R5: Removing the check for matching scan variables and
reduction variables
---
mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp | 25 +++---------------
mlir/test/Dialect/OpenMP/invalid.mlir | 27 --------------------
2 files changed, 3 insertions(+), 49 deletions(-)
diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
index a8bf501ef28e6e..aaa0eb8455d2a5 100644
--- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
+++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
@@ -3167,38 +3167,19 @@ LogicalResult ScanOp::verify() {
return emitError(
"Exactly one of EXCLUSIVE or INCLUSIVE clause is expected");
}
- const OperandRange &scanVars =
- hasExclusiveVars() ? getExclusiveVars() : getInclusiveVars();
- auto verifyScanVarsInReduction = [&scanVars](ValueRange reductionVars) {
- for (const auto &it : scanVars)
- if (!llvm::is_contained(reductionVars, it))
- return false;
- return true;
- };
if (mlir::omp::WsloopOp parentWsLoopOp =
(*this)->getParentOfType<mlir::omp::WsloopOp>()) {
if (parentWsLoopOp.getReductionModAttr() &&
parentWsLoopOp.getReductionModAttr().getValue() ==
mlir::omp::ReductionModifier::inscan) {
- auto iface = llvm::cast<mlir::omp::BlockArgOpenMPOpInterface>(
- parentWsLoopOp.getOperation());
- if (!verifyScanVarsInReduction(iface.getReductionBlockArgs())) {
- return emitError(
- "List item should appear in REDUCTION clause of the parent");
- }
return success();
}
- } else if (mlir::omp::SimdOp parentSimdOp =
- (*this)->getParentOfType<mlir::omp::SimdOp>()) {
+ }
+ if (mlir::omp::SimdOp parentSimdOp =
+ (*this)->getParentOfType<mlir::omp::SimdOp>()) {
if (parentSimdOp.getReductionModAttr() &&
parentSimdOp.getReductionModAttr().getValue() ==
mlir::omp::ReductionModifier::inscan) {
- auto iface = llvm::cast<mlir::omp::BlockArgOpenMPOpInterface>(
- parentSimdOp.getOperation());
- if (!verifyScanVarsInReduction(iface.getReductionBlockArgs())) {
- return emitError(
- "List item should appear in REDUCTION clause of the parent");
- }
return success();
}
}
diff --git a/mlir/test/Dialect/OpenMP/invalid.mlir b/mlir/test/Dialect/OpenMP/invalid.mlir
index e3e6397d84d77c..514792e425f89e 100644
--- a/mlir/test/Dialect/OpenMP/invalid.mlir
+++ b/mlir/test/Dialect/OpenMP/invalid.mlir
@@ -1837,33 +1837,6 @@ combiner {
omp.yield (%1 : f32)
}
-func.func @scan_test_1(%lb: i32, %ub: i32, %step: i32) {
- %test1f32 = "test.f32"() : () -> (!llvm.ptr)
- %test2f32 = "test.f32"() : () -> (!llvm.ptr)
- omp.wsloop reduction(mod:inscan, @add_f32 %test1f32 -> %arg1 : !llvm.ptr) {
- omp.loop_nest (%i, %j) : i32 = (%lb, %ub) to (%ub, %lb) step (%step, %step) {
- // expected-error @below {{List item should appear in REDUCTION clause of the parent}}
- omp.scan inclusive(%test2f32 : !llvm.ptr)
- omp.yield
- }
- }
- return
-}
-
-// -----
-
-omp.declare_reduction @add_f32 : f32
-init {
- ^bb0(%arg: f32):
- %0 = arith.constant 0.0 : f32
- omp.yield (%0 : f32)
-}
-combiner {
- ^bb1(%arg0: f32, %arg1: f32):
- %1 = arith.addf %arg0, %arg1 : f32
- omp.yield (%1 : f32)
-}
-
func.func @scan_test_2(%lb: i32, %ub: i32, %step: i32) {
%test1f32 = "test.f32"() : () -> (!llvm.ptr)
omp.wsloop reduction(mod:inscan, @add_f32 %test1f32 -> %arg1 : !llvm.ptr) {
More information about the Mlir-commits
mailing list