[Mlir-commits] [mlir] inscan reduction and scan op mlir support (PR #114737)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Sun Nov 3 21:06:51 PST 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-openmp
@llvm/pr-subscribers-flang-openmp
Author: Anchu Rajendran S (anchuraj)
<details>
<summary>Changes</summary>
The scan directive allows scan reductions to be specified within a worksharing-loop, worksharing-loop SIMD, or SIMD construct; the reduction clause of that construct must carry the `InScan` modifier. This change adds MLIR support for the scan directive and the `InScan` reduction modifier.
---
Patch is 26.06 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/114737.diff
6 Files Affected:
- (modified) mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td (+57)
- (modified) mlir/include/mlir/Dialect/OpenMP/OpenMPEnums.td (+21)
- (modified) mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td (+26-6)
- (modified) mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp (+1)
- (modified) mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp (+95-33)
- (modified) mlir/test/Dialect/OpenMP/ops.mlir (+23)
``````````diff
diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
index 886554f66afffc..b45d89463639c5 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
@@ -283,6 +283,34 @@ class OpenMP_DoacrossClauseSkip<
def OpenMP_DoacrossClause : OpenMP_DoacrossClauseSkip<>;
+//===----------------------------------------------------------------------===//
+// V5.2: [5.4.7] `exclusive` clause
+//===----------------------------------------------------------------------===//
+
+class OpenMP_ExclusiveClauseSkip<
+ bit traits = false, bit arguments = false, bit assemblyFormat = false,
+ bit description = false, bit extraClassDeclaration = false
+ > : OpenMP_Clause<traits, arguments, assemblyFormat, description,
+ extraClassDeclaration> {
+ let arguments = (ins
+ Variadic<AnyType>:$exclusive_vars
+ );
+
+ let optAssemblyFormat = [{
+ `exclusive` `(` $exclusive_vars `:` type($exclusive_vars) `)`
+ }];
+
+ let description = [{
+ The exclusive clause is used on a separating directive that separates a
+ structured block into two structured block sequences. If it
+ is specified, the input phase excludes the preceding structured block
+ sequence and instead includes the following structured block sequence,
+ while the scan phase includes the preceding structured block sequence.
+ }];
+}
+
+def OpenMP_ExclusiveClause : OpenMP_ExclusiveClauseSkip<>;
+
//===----------------------------------------------------------------------===//
// V5.2: [10.5.1] `filter` clause
//===----------------------------------------------------------------------===//
@@ -393,6 +421,33 @@ class OpenMP_HasDeviceAddrClauseSkip<
def OpenMP_HasDeviceAddrClause : OpenMP_HasDeviceAddrClauseSkip<>;
+//===----------------------------------------------------------------------===//
+// V5.2: [5.4.7] `inclusive` clause
+//===----------------------------------------------------------------------===//
+
+class OpenMP_InclusiveClauseSkip<
+ bit traits = false, bit arguments = false, bit assemblyFormat = false,
+ bit description = false, bit extraClassDeclaration = false
+ > : OpenMP_Clause<traits, arguments, assemblyFormat, description,
+ extraClassDeclaration> {
+ let arguments = (ins
+ Variadic<AnyType>:$inclusive_vars
+ );
+
+ let optAssemblyFormat = [{
+ `inclusive` `(` $inclusive_vars `:` type($inclusive_vars) `)`
+ }];
+
+ let description = [{
+ The inclusive clause is used on a separating directive that separates a
+ structured block into two structured block sequences. If it is specified,
+ the input phase includes the preceding structured block sequence and the
+ scan phase includes the following structured block sequence.
+ }];
+}
+
+def OpenMP_InclusiveClause : OpenMP_InclusiveClauseSkip<>;
+
//===----------------------------------------------------------------------===//
// V5.2: [15.1.2] `hint` clause
//===----------------------------------------------------------------------===//
@@ -983,6 +1039,7 @@ class OpenMP_ReductionClauseSkip<
];
let arguments = (ins
+ OptionalAttr<ReductionModifierAttr>:$reduction_mod,
Variadic<OpenMP_PointerLikeType>:$reduction_vars,
OptionalAttr<DenseBoolArrayAttr>:$reduction_byref,
OptionalAttr<SymbolRefArrayAttr>:$reduction_syms
diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPEnums.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPEnums.td
index b1a9e3330522b2..23086556bbb2f5 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPEnums.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPEnums.td
@@ -178,6 +178,27 @@ def OrderModifier
def OrderModifierAttr : EnumAttr<OpenMP_Dialect, OrderModifier,
"order_mod">;
+//===----------------------------------------------------------------------===//
+// reduction_modifier enum.
+//===----------------------------------------------------------------------===//
+
+def ReductionModifierInScan : I32EnumAttrCase<"InScan", 0>;
+def ReductionModifierTask : I32EnumAttrCase<"Task", 1>;
+def ReductionModifierDefault : I32EnumAttrCase<"Default", 2>;
+
+def ReductionModifier : OpenMP_I32EnumAttr<
+ "ReductionModifier",
+ "reduction modifier", [
+ ReductionModifierInScan,
+ ReductionModifierTask,
+ ReductionModifierDefault
+ ]>;
+
+def ReductionModifierAttr : OpenMP_EnumAttr<ReductionModifier,
+ "reduction_modifier"> {
+ let assemblyFormat = "`(` $value `)`";
+}
+
//===----------------------------------------------------------------------===//
// sched_mod enum.
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
index 626539cb7bde42..a03f18a816c39e 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
@@ -170,7 +170,7 @@ def ParallelOp : OpenMP_Op<"parallel", traits = [
let assemblyFormat = clausesAssemblyFormat # [{
custom<PrivateReductionRegion>($region, $private_vars, type($private_vars),
- $private_syms, $reduction_vars, type($reduction_vars), $reduction_byref,
+ $private_syms, $reduction_mod, $reduction_vars, type($reduction_vars), $reduction_byref,
$reduction_syms) attr-dict
}];
@@ -215,7 +215,7 @@ def TeamsOp : OpenMP_Op<"teams", traits = [
let assemblyFormat = clausesAssemblyFormat # [{
custom<PrivateReductionRegion>($region, $private_vars, type($private_vars),
- $private_syms, $reduction_vars, type($reduction_vars), $reduction_byref,
+ $private_syms, $reduction_mod, $reduction_vars, type($reduction_vars), $reduction_byref,
$reduction_syms) attr-dict
}];
@@ -274,7 +274,7 @@ def SectionsOp : OpenMP_Op<"sections", traits = [
let assemblyFormat = clausesAssemblyFormat # [{
custom<PrivateReductionRegion>($region, $private_vars, type($private_vars),
- $private_syms, $reduction_vars, type($reduction_vars), $reduction_byref,
+ $private_syms, $reduction_mod, $reduction_vars, type($reduction_vars), $reduction_byref,
$reduction_syms) attr-dict
}];
@@ -422,7 +422,7 @@ def WsloopOp : OpenMP_Op<"wsloop", traits = [
let assemblyFormat = clausesAssemblyFormat # [{
custom<PrivateReductionRegion>($region, $private_vars, type($private_vars),
- $private_syms, $reduction_vars, type($reduction_vars), $reduction_byref,
+ $private_syms, $reduction_mod, $reduction_vars, type($reduction_vars), $reduction_byref,
$reduction_syms) attr-dict
}];
@@ -476,7 +476,7 @@ def SimdOp : OpenMP_Op<"simd", traits = [
let assemblyFormat = clausesAssemblyFormat # [{
custom<PrivateReductionRegion>($region, $private_vars, type($private_vars),
- $private_syms, $reduction_vars, type($reduction_vars), $reduction_byref,
+ $private_syms, $reduction_mod, $reduction_vars, type($reduction_vars), $reduction_byref,
$reduction_syms) attr-dict
}];
@@ -680,7 +680,7 @@ def TaskloopOp : OpenMP_Op<"taskloop", traits = [
custom<InReductionPrivateReductionRegion>(
$region, $in_reduction_vars, type($in_reduction_vars),
$in_reduction_byref, $in_reduction_syms, $private_vars,
- type($private_vars), $private_syms, $reduction_vars,
+ type($private_vars), $private_syms, $reduction_mod, $reduction_vars,
type($reduction_vars), $reduction_byref, $reduction_syms) attr-dict
}];
@@ -1560,6 +1560,28 @@ def CancellationPointOp : OpenMP_Op<"cancellation_point", clauses = [
let hasVerifier = 1;
}
+def ScanOp : OpenMP_Op<"scan", [
+ AttrSizedOperandSegments
+ ], clauses = [
+ OpenMP_InclusiveClause, OpenMP_ExclusiveClause]> {
+ let summary = "scan directive";
+ let description = [{
+ The scan directive allows scan reductions to be specified. It must be
+ enclosed in a parent directive that carries a reduction clause with the
+ `InScan` modifier. The scan directive separates the code in the region
+ enclosed by the parent into an input phase and a scan phase.
+ }] # clausesDescription;
+
+ // Two variadic operand groups (inclusive_vars, exclusive_vars) require
+ // AttrSizedOperandSegments. The op has no region, so region traits such as
+ // IsolatedFromAbove do not apply.
+ let builders = [
+ OpBuilder<(ins CArg<"const ScanOperands &">:$clauses)>
+ ];
+
+ let hasVerifier = 1;
+}
+
//===----------------------------------------------------------------------===//
// 2.19.5.7 declare reduction Directive
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp b/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp
index aa241b91d758ca..233739e1d6d917 100644
--- a/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp
+++ b/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp
@@ -451,6 +451,7 @@ struct ParallelOpLowering : public OpRewritePattern<scf::ParallelOp> {
/* private_vars = */ ValueRange(),
/* private_syms = */ nullptr,
/* proc_bind_kind = */ omp::ClauseProcBindKindAttr{},
+ /* reduction_mod = */ nullptr,
/* reduction_vars = */ llvm::SmallVector<Value>{},
/* reduction_byref = */ DenseBoolArrayAttr{},
/* reduction_syms = */ ArrayAttr{});
diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
index e1df647d6a3c71..0ad7fe2c2cf243 100644
--- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
+++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
@@ -491,16 +491,42 @@ struct PrivateParseArgs {
SmallVectorImpl<Type> &types, ArrayAttr &syms)
: vars(vars), types(types), syms(syms) {}
};
+
+// Fallback storage bound by non-const reference when a caller supplies no
+// reduction modifier: the 4-argument ReductionParseArgs constructor binds its
+// `reductionMod` member to it, and parseClauseWithRegionArgs uses it as the
+// default for its `reductionMod` parameter.
+// NOTE(review): this is mutable namespace-scope state. A parse that reaches
+// the `type:` branch through one of these fallbacks (presumably an
+// `in_reduction` clause, whose args use the 4-argument constructor) writes
+// into it; ReductionPrintArgs also defaults to reading it, so a stale value
+// can leak into later parses/prints, and it would race if parses ever ran
+// concurrently. Consider a `ReductionModifierAttr *` defaulting to nullptr.
+static ReductionModifierAttr nullReductionMod = nullptr;
+// Parse-time references to the operands, types, by-ref flags, symbols and
+// optional modifier of a reduction-like clause.
struct ReductionParseArgs {
SmallVectorImpl<OpAsmParser::UnresolvedOperand> &vars;
SmallVectorImpl<Type> &types;
DenseBoolArrayAttr &byref;
ArrayAttr &syms;
+ ReductionModifierAttr &reductionMod;
+ // Binds the caller-provided modifier attribute.
+ ReductionParseArgs(SmallVectorImpl<OpAsmParser::UnresolvedOperand> &vars,
+ SmallVectorImpl<Type> &types, DenseBoolArrayAttr &byref,
+ ArrayAttr &syms, ReductionModifierAttr &redMod)
+ : vars(vars), types(types), byref(byref), syms(syms),
+ reductionMod(redMod) {}
+ // No modifier supplied: binds to the shared nullReductionMod fallback.
ReductionParseArgs(SmallVectorImpl<OpAsmParser::UnresolvedOperand> &vars,
SmallVectorImpl<Type> &types, DenseBoolArrayAttr &byref,
ArrayAttr &syms)
- : vars(vars), types(types), byref(byref), syms(syms) {}
+ : vars(vars), types(types), byref(byref), syms(syms),
+ reductionMod(nullReductionMod) {}
};
+
+// Parse arguments for each clause that can define entry block arguments of an
+// op's region (e.g. in_reduction, map, private, reduction).
struct AllRegionParseArgs {
std::optional<ReductionParseArgs> inReductionArgs;
std::optional<MapParseArgs> mapArgs;
@@ -517,7 +528,8 @@ static ParseResult parseClauseWithRegionArgs(
SmallVectorImpl<OpAsmParser::UnresolvedOperand> &operands,
SmallVectorImpl<Type> &types,
SmallVectorImpl<OpAsmParser::Argument> &regionPrivateArgs,
- ArrayAttr *symbols = nullptr, DenseBoolArrayAttr *byref = nullptr) {
+ ArrayAttr *symbols = nullptr, DenseBoolArrayAttr *byref = nullptr,
+ ReductionModifierAttr &reductionMod = nullReductionMod) {
SmallVector<SymbolRefAttr> symbolVec;
SmallVector<bool> isByRefVec;
unsigned regionArgOffset = regionPrivateArgs.size();
@@ -525,6 +537,22 @@ static ParseResult parseClauseWithRegionArgs(
if (parser.parseLParen())
return failure();
+ // Optional reduction modifier prefix: `type: <modifier>,`.
+ StringRef enumStr;
+ if (succeeded(parser.parseOptionalKeyword("type"))) {
+ if (parser.parseColon())
+ return failure();
+ llvm::SMLoc modLoc = parser.getCurrentLocation();
+ if (parser.parseKeyword(&enumStr) || parser.parseComma())
+ return failure();
+ std::optional<ReductionModifier> enumValue =
+ symbolizeReductionModifier(enumStr);
+ // Reject unknown modifiers instead of dereferencing an empty optional.
+ if (!enumValue)
+ return parser.emitError(modLoc, "invalid reduction modifier: ") << enumStr;
+ reductionMod = ReductionModifierAttr::get(parser.getContext(), *enumValue);
+ }
+
if (parser.parseCommaSeparatedList([&]() {
if (byref)
isByRefVec.push_back(
@@ -615,15 +637,16 @@ static ParseResult parseBlockArgClause(
if (succeeded(parser.parseOptionalKeyword(keyword))) {
if (!reductionArgs)
return failure();

if (failed(parseClauseWithRegionArgs(
parser, reductionArgs->vars, reductionArgs->types, entryBlockArgs,
- &reductionArgs->syms, &reductionArgs->byref)))
+ &reductionArgs->syms, &reductionArgs->byref,
+ reductionArgs->reductionMod)))
return failure();
}
return success();
}

static ParseResult parseBlockArgRegion(OpAsmParser &parser, Region &region,
AllRegionParseArgs args) {
llvm::SmallVector<OpAsmParser::Argument> entryBlockArgs;
@@ -704,6 +725,7 @@ static ParseResult parseInReductionPrivateReductionRegion(
DenseBoolArrayAttr &inReductionByref, ArrayAttr &inReductionSyms,
llvm::SmallVectorImpl<OpAsmParser::UnresolvedOperand> &privateVars,
llvm::SmallVectorImpl<Type> &privateTypes, ArrayAttr &privateSyms,
+ ReductionModifierAttr &reductionMod,
SmallVectorImpl<OpAsmParser::UnresolvedOperand> &reductionVars,
SmallVectorImpl<Type> &reductionTypes, DenseBoolArrayAttr &reductionByref,
ArrayAttr &reductionSyms) {
@@ -712,7 +734,7 @@ static ParseResult parseInReductionPrivateReductionRegion(
inReductionByref, inReductionSyms);
args.privateArgs.emplace(privateVars, privateTypes, privateSyms);
args.reductionArgs.emplace(reductionVars, reductionTypes, reductionByref,
- reductionSyms);
+ reductionSyms, reductionMod);
return parseBlockArgRegion(parser, region, args);
}
@@ -729,13 +751,14 @@ static ParseResult parsePrivateReductionRegion(
OpAsmParser &parser, Region &region,
llvm::SmallVectorImpl<OpAsmParser::UnresolvedOperand> &privateVars,
llvm::SmallVectorImpl<Type> &privateTypes, ArrayAttr &privateSyms,
+ ReductionModifierAttr &reductionMod,
SmallVectorImpl<OpAsmParser::UnresolvedOperand> &reductionVars,
SmallVectorImpl<Type> &reductionTypes, DenseBoolArrayAttr &reductionByref,
ArrayAttr &reductionSyms) {
AllRegionParseArgs args;
args.privateArgs.emplace(privateVars, privateTypes, privateSyms);
args.reductionArgs.emplace(reductionVars, reductionTypes, reductionByref,
- reductionSyms);
+ reductionSyms, reductionMod);
return parseBlockArgRegion(parser, region, args);
}
@@ -784,9 +807,12 @@ struct ReductionPrintArgs {
TypeRange types;
DenseBoolArrayAttr byref;
ArrayAttr syms;
+ ReductionModifierAttr reductionMod;
ReductionPrintArgs(ValueRange vars, TypeRange types, DenseBoolArrayAttr byref,
- ArrayAttr syms)
- : vars(vars), types(types), byref(byref), syms(syms) {}
+ ArrayAttr syms,
+ ReductionModifierAttr reductionMod = nullReductionMod)
+ : vars(vars), types(types), byref(byref), syms(syms),
+ reductionMod(reductionMod) {}
};
struct AllRegionPrintArgs {
std::optional<ReductionPrintArgs> inReductionArgs;
@@ -799,17 +825,21 @@ struct AllRegionPrintArgs {
};
} // namespace
-static void printClauseWithRegionArgs(OpAsmPrinter &p, MLIRContext *ctx,
- StringRef clauseName,
- ValueRange argsSubrange,
- ValueRange operands, TypeRange types,
- ArrayAttr symbols = nullptr,
- DenseBoolArrayAttr byref = nullptr) {
+static void printClauseWithRegionArgs(
+ OpAsmPrinter &p, MLIRContext *ctx, StringRef clauseName,
+ ValueRange argsSubrange, ValueRange operands, TypeRange types,
+ ArrayAttr symbols = nullptr, DenseBoolArrayAttr byref = nullptr,
+ ReductionModifierAttr reductionMod = nullptr) {
if (argsSubrange.empty())
return;
p << clauseName << "(";
+ if (reductionMod) {
+ p << "type: " << stringifyReductionModifier(reductionMod.getValue())
+ << ", ";
+ }
+
if (!symbols) {
llvm::SmallVector<Attribute> values(operands.size(), nullptr);
symbols = ArrayAttr::get(ctx, values);
@@ -859,7 +889,8 @@ printBlockArgClause(OpAsmPrinter &p, MLIRContext *ctx, StringRef clauseName,
if (reductionArgs)
printClauseWithRegionArgs(p, ctx, clauseName, argsSubrange,
reductionArgs->vars, reductionArgs->types,
- reductionArgs->syms, reductionArgs->byref);
+ reductionArgs->syms, reductionArgs->byref,
+ reductionArgs->reductionMod);
}
static void printBlockArgRegion(OpAsmPrinter &p, Operation *op, Region &region,
@@ -916,14 +947,15 @@ static void printInReductionPrivateReductionRegion(
OpAsmPrinter &p, Operation *op, Region &region, ValueRange inReductionVars,
TypeRange inReductionTypes, DenseBoolArrayAttr inReductionByref,
ArrayAttr inReductionSyms, ValueRange privateVars, TypeRange privateTypes,
- ArrayAttr privateSyms, ValueRange reductionVars, TypeRange reductionTypes,
+ ArrayAttr privateSyms, ReductionModifierAttr reductionMod,
+ ValueRange reductionVars, TypeRange reductionTypes,
DenseBoolArrayAttr reductionByref, ArrayAttr reductionSyms) {
AllRegionPrintArgs args;
args.inReductionArgs.emplace(inReductionVars, inReductionTypes,
inReductionByref, inReductionSyms);
args.privateArgs.emplace(privateVars, privateTypes, privateSyms);
args.reductionArgs.emplace(reductionVars, reductionTypes, reductionByref,
- reductionSyms);
+ reductionSyms, reductionMod);
printBlockArgRegion(p, op, region, args);
}
@@ -937,13 +969,14 @@ static void printPrivateRegion(OpAsmPrinter &p, Operation *op, Region &region,
static void printPrivateReductionRegion(
OpAsmPrinter &p, Operation *op, Region &region, ValueRange privateVars,
- TypeRange privateTypes, ArrayAttr privateSyms, ValueRange reductionVars,
+ TypeRange privateTypes, ArrayAttr privateSyms,
+ ReductionModifierAttr reductionMod, ValueRange reductionVars,
TypeRange reductionTypes, DenseBoolArrayAttr reductionByref,
ArrayAttr reductionSyms) {
AllRegionPrintArgs args;
args.privateArgs.emplace(privateVars, privateTypes, privateSyms);
args.reductionArgs.emplace(reductionVars, reductionTypes, reductionByref,
- reductionSyms);
+ reductionSyms, reductionMod);
printBlockArgRegion(p, op, region, args);
}
@@ -1700,7 +1733,7 @@ void ParallelOp::build(OpBuilder &builder, OperationState &state,
/*allocator_vars=*/ValueRange(), /*if_expr=*/nullptr,
/*num_threads=*/nullptr, /*private_vars=*/ValueRange(),
/*private_syms=*/nullptr, /*proc_bind_kind=*/nullptr,
- /*reduction_vars=*/ValueRange(),
+ /*reduction_mod =*/nullptr, /*reduction_vars=*/ValueRange(),
/*reduction_byref=*/nullptr, /*reduction_syms=*/nullptr);
state.addAttributes(attributes);
}
@@ -1711,7 +1744,8 @@ void ParallelOp::build(OpBuilder &builder, OperationState &state,
ParallelOp::build(builder, state, clauses.allocateVars, clauses.allocatorVars,
clauses.ifExpr, clauses.numThreads, clauses.privateVars,
makeArrayAttr(ctx, clauses.privateSyms),
- clauses.procBindKind, clauses.reductionVars,
+ clauses.procBindKind, clauses.reductionMod,
+ clauses.reductionVars,
makeDenseBoolArrayAttr(ctx, clauses.reductionByref),
makeArrayAttr(ctx, clauses.reductionSyms));
}
@@ -1810,12 +1844,13 @@ void TeamsOp::build(OpBuilder &builder, OperationState &state,
const TeamsOperands &clauses) {
MLIRContext *ctx = builder.getContext();
// TODO Store clauses in op: privateVars, privateSyms.
- TeamsOp::build(
- builder, state, clauses.allocateVars, clauses.allocatorVars,
- clauses.ifExpr, clauses.numTeamsLower, clauses.numTeamsUpper,
- /*private_vars=*/{}, /*private_syms=*/nullptr, cla...
[truncated]
``````````
</details>
https://github.com/llvm/llvm-project/pull/114737
More information about the Mlir-commits
mailing list