[flang-commits] [flang] [flang] Add reduction semantics to fir.do_loop (PR #93934)
via flang-commits
flang-commits at lists.llvm.org
Fri May 31 01:01:04 PDT 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-flang-fir-hlfir
Author: None (khaki3)
<details>
<summary>Changes</summary>
Derived from #92480. This PR introduces `fir.reduce` into `fir.do_loop`. The operation `fir.reduce` conveys reduction semantics in a way similar to `acc.reduction`: it marks the reference to a reduction variable while preserving the variable's original name. The `fir.do_loop` operation now implicitly carries the `operandSegmentSizes` attribute. It takes a variable-length list of reduction operands, with their operations given as `fir.reduce_attr`. For compatibility, the new arguments of `fir.do_loop`'s builder are appended at the end. The `iter_args` operand must stay next to the return-type declaration, so the new operands for `fir.reduce` are placed in the middle of the operand list.
---
Full diff: https://github.com/llvm/llvm-project/pull/93934.diff
4 Files Affected:
- (modified) flang/include/flang/Optimizer/Dialect/FIRAttr.td (+30)
- (modified) flang/include/flang/Optimizer/Dialect/FIROps.td (+56-8)
- (modified) flang/lib/Optimizer/Dialect/FIRAttr.cpp (+2-2)
- (modified) flang/lib/Optimizer/Dialect/FIROps.cpp (+66-7)
``````````diff
diff --git a/flang/include/flang/Optimizer/Dialect/FIRAttr.td b/flang/include/flang/Optimizer/Dialect/FIRAttr.td
index 0c34b640a5c9c..aedb6769186e9 100644
--- a/flang/include/flang/Optimizer/Dialect/FIRAttr.td
+++ b/flang/include/flang/Optimizer/Dialect/FIRAttr.td
@@ -67,6 +67,36 @@ def fir_BoxFieldAttr : I32EnumAttr<
let cppNamespace = "fir";
}
+def fir_ReduceOperationEnum : I32BitEnumAttr<"ReduceOperationEnum", // NOTE(review): a bit enum admits OR-ed combinations of cases, while one reduction presumably needs exactly one operation -- confirm I32EnumAttr was not intended.
+ "intrinsic operations and functions supported by DO CONCURRENT REDUCE",
+ [
+ I32BitEnumAttrCaseBit<"Add", 0, "add">,
+ I32BitEnumAttrCaseBit<"Multiply", 1, "multiply">,
+ I32BitEnumAttrCaseBit<"AND", 2, "and">,
+ I32BitEnumAttrCaseBit<"OR", 3, "or">,
+ I32BitEnumAttrCaseBit<"EQV", 4, "eqv">,
+ I32BitEnumAttrCaseBit<"NEQV", 5, "neqv">,
+ I32BitEnumAttrCaseBit<"MAX", 6, "max">,
+ I32BitEnumAttrCaseBit<"MIN", 7, "min">,
+ I32BitEnumAttrCaseBit<"IAND", 8, "iand">,
+ I32BitEnumAttrCaseBit<"IOR", 9, "ior">,
+ I32BitEnumAttrCaseBit<"EIOR", 10, "eior"> // NOTE(review): the Fortran intrinsic is spelled IEOR; confirm the "EIOR"/"eior" spelling is intentional.
+ ]> {
+ let separator = ", "; // Text placed between case names when several bits are set.
+ let cppNamespace = "::fir";
+ let printBitEnumPrimaryGroups = 1;
+}
+
+def fir_ReduceAttr : fir_Attr<"Reduce"> { // Wraps a ReduceOperationEnum value; used by fir.do_loop's reduce(...) list, e.g. #fir.reduce_attr<add>.
+ let mnemonic = "reduce_attr";
+
+ let parameters = (ins
+ "ReduceOperationEnum":$reduce_operation // The reduction operation being applied.
+ );
+
+ let assemblyFormat = "`<` $reduce_operation `>`"; // Printed as reduce_attr<op>, e.g. reduce_attr<add>.
+}
+
// mlir::SideEffects::Resource for modelling operations which add debugging information
def DebuggingResource : Resource<"::fir::DebuggingResource">;
diff --git a/flang/include/flang/Optimizer/Dialect/FIROps.td b/flang/include/flang/Optimizer/Dialect/FIROps.td
index 3afc97475db11..d79f2da916d05 100644
--- a/flang/include/flang/Optimizer/Dialect/FIROps.td
+++ b/flang/include/flang/Optimizer/Dialect/FIROps.td
@@ -2107,8 +2107,37 @@ class region_Op<string mnemonic, list<Trait> traits = []> :
let hasVerifier = 1;
}
-def fir_DoLoopOp : region_Op<"do_loop",
- [DeclareOpInterfaceMethods<LoopLikeOpInterface,
+def fir_ReduceOp : fir_SimpleOp<"reduce", [NoMemoryEffect]> { // Tags a reduction variable for the reduce clause; declared free of memory effects.
+ let summary = "Represent reduction semantics for the reduce clause";
+
+ let description = [{
+ Given the address of a variable, creates reduction information for the
+ reduce clause.
+
+ ```
+ %17 = fir.reduce %8 {name = "sum"} : (!fir.ref<f32>) -> !fir.ref<f32>
+ fir.do_loop ... unordered reduce(#fir.reduce_attr<add> -> %17 : !fir.ref<f32>) ...
+ ```
+
+ This operation is typically used for DO CONCURRENT REDUCE clause. The memref
+ operand may have a unique name while the `name` attribute preserves the
+ original name of a reduction variable.
+ }];
+
+ let arguments = (ins
+ AnyRefOrBoxLike:$memref, // Address of the reduction variable (see description).
+ Builtin_StringAttr:$name // Original source-level name of the reduction variable.
+ );
+
+ let results = (outs AnyRefOrBox); // NOTE(review): operand is AnyRefOrBoxLike but result is AnyRefOrBox, and no trait ties the result type to the operand type although the example uses one type for both -- confirm.
+
+ let assemblyFormat = [{
+ operands attr-dict `:` functional-type(operands, results)
+ }];
+}
+
+def fir_DoLoopOp : region_Op<"do_loop", [AttrSizedOperandSegments,
+ DeclareOpInterfaceMethods<LoopLikeOpInterface,
["getYieldedValuesMutable"]>]> {
let summary = "generalized loop operation";
let description = [{
@@ -2138,9 +2167,11 @@ def fir_DoLoopOp : region_Op<"do_loop",
Index:$lowerBound,
Index:$upperBound,
Index:$step,
+ Variadic<AnyType>:$reduceOperands,
Variadic<AnyType>:$initArgs,
OptionalAttr<UnitAttr>:$unordered,
- OptionalAttr<UnitAttr>:$finalValue
+ OptionalAttr<UnitAttr>:$finalValue,
+ OptionalAttr<ArrayAttr>:$reduceAttrs
);
let results = (outs Variadic<AnyType>:$results);
let regions = (region SizedRegion<1>:$region);
@@ -2151,6 +2182,8 @@ def fir_DoLoopOp : region_Op<"do_loop",
"mlir::Value":$step, CArg<"bool", "false">:$unordered,
CArg<"bool", "false">:$finalCountValue,
CArg<"mlir::ValueRange", "std::nullopt">:$iterArgs,
+ CArg<"mlir::ValueRange", "std::nullopt">:$reduceOperands,
+ CArg<"llvm::ArrayRef<mlir::Attribute>", "{}">:$reduceAttrs,
CArg<"llvm::ArrayRef<mlir::NamedAttribute>", "{}">:$attributes)>
];
@@ -2163,11 +2196,12 @@ def fir_DoLoopOp : region_Op<"do_loop",
return getBody()->getArguments().drop_front();
}
mlir::Operation::operand_range getIterOperands() {
- return getOperands().drop_front(getNumControlOperands());
+ // The loop-carried operands are exactly the ODS `initArgs` segment.
+ return getInitArgs();
}
llvm::MutableArrayRef<mlir::OpOperand> getInitsMutable() {
- return
- getOperation()->getOpOperands().drop_front(getNumControlOperands());
+ unsigned firstInit = getNumControlOperands() + getNumReduceOperands();
+ return getOperation()->getOpOperands().drop_front(firstInit);
}
void setLowerBound(mlir::Value bound) { (*this)->setOperand(0, bound); }
@@ -2182,11 +2216,25 @@ def fir_DoLoopOp : region_Op<"do_loop",
unsigned getNumControlOperands() { return 3; }
/// Does the operation hold operands for loop-carried values
bool hasIterOperands() {
- return (*this)->getNumOperands() > getNumControlOperands();
+ return getNumIterOperands() != 0;
+ }
+ /// True when the loop lists at least one reduction variable.
+ bool hasReduceOperands() {
+ return !getReduceOperands().empty();
+ }
+ /// Get the number of operands in ODS operand segment `idx`
+ /// (argument order: 0-2 bounds/step, 3 reduceOperands, 4 initArgs).
+ /// NOTE(review): this overload hides the inherited getNumOperands().
+ unsigned getNumOperands(unsigned idx) {
+ return getODSOperandIndexAndLength(idx).second;
+ }
+ /// Get Number of reduction operands
+ unsigned getNumReduceOperands() {
+ return getReduceOperands().size();
}
/// Get Number of loop-carried values
unsigned getNumIterOperands() {
- return (*this)->getNumOperands() - getNumControlOperands();
+ return getInitArgs().size();
}
/// Get the body of the loop
diff --git a/flang/lib/Optimizer/Dialect/FIRAttr.cpp b/flang/lib/Optimizer/Dialect/FIRAttr.cpp
index 2faba63dfba07..a0202a0159228 100644
--- a/flang/lib/Optimizer/Dialect/FIRAttr.cpp
+++ b/flang/lib/Optimizer/Dialect/FIRAttr.cpp
@@ -297,6 +297,6 @@ void fir::printFirAttribute(FIROpsDialect *dialect, mlir::Attribute attr,
void FIROpsDialect::registerAttributes() {
addAttributes<ClosedIntervalAttr, ExactTypeAttr, FortranVariableFlagsAttr,
- LowerBoundAttr, PointIntervalAttr, RealAttr, SubclassAttr,
- UpperBoundAttr>();
+ LowerBoundAttr, PointIntervalAttr, RealAttr, ReduceAttr, // ReduceAttr added; list kept alphabetical.
+ SubclassAttr, UpperBoundAttr>();
}
diff --git a/flang/lib/Optimizer/Dialect/FIROps.cpp b/flang/lib/Optimizer/Dialect/FIROps.cpp
index b541b7cdc7a5b..807459c8ec3c7 100644
--- a/flang/lib/Optimizer/Dialect/FIROps.cpp
+++ b/flang/lib/Optimizer/Dialect/FIROps.cpp
@@ -2079,9 +2079,16 @@ void fir::DoLoopOp::build(mlir::OpBuilder &builder,
mlir::OperationState &result, mlir::Value lb,
mlir::Value ub, mlir::Value step, bool unordered,
bool finalCountValue, mlir::ValueRange iterArgs,
+ mlir::ValueRange reduceOperands,
+ llvm::ArrayRef<mlir::Attribute> reduceAttrs,
llvm::ArrayRef<mlir::NamedAttribute> attributes) {
result.addOperands({lb, ub, step});
+ result.addOperands(reduceOperands);
result.addOperands(iterArgs);
+ result.addAttribute(getOperandSegmentSizeAttr(),
+ builder.getDenseI32ArrayAttr(
+ {1, 1, 1, static_cast<int32_t>(reduceOperands.size()),
+ static_cast<int32_t>(iterArgs.size())}));
if (finalCountValue) {
result.addTypes(builder.getIndexType());
result.addAttribute(getFinalValueAttrName(result.name),
@@ -2100,6 +2107,9 @@ void fir::DoLoopOp::build(mlir::OpBuilder &builder,
if (unordered)
result.addAttribute(getUnorderedAttrName(result.name),
builder.getUnitAttr());
+ if (!reduceAttrs.empty())
+ result.addAttribute(getReduceAttrsAttrName(result.name),
+ builder.getArrayAttr(reduceAttrs));
result.addAttributes(attributes);
}
@@ -2125,24 +2135,51 @@ mlir::ParseResult fir::DoLoopOp::parse(mlir::OpAsmParser &parser,
if (mlir::succeeded(parser.parseOptionalKeyword("unordered")))
result.addAttribute("unordered", builder.getUnitAttr());
+ // Parse the reduction arguments.
+ llvm::SmallVector<mlir::OpAsmParser::UnresolvedOperand> reduceOperands;
+ llvm::SmallVector<mlir::Type> reduceArgTypes;
+ if (succeeded(parser.parseOptionalKeyword("reduce"))) {
+ // Parse reduction attributes and variables.
+ llvm::SmallVector<ReduceAttr> attributes;
+ if (failed(parser.parseCommaSeparatedList(
+ mlir::AsmParser::Delimiter::Paren, [&]() {
+ if (parser.parseAttribute(attributes.emplace_back()) ||
+ parser.parseArrow() ||
+ parser.parseOperand(reduceOperands.emplace_back()) ||
+ parser.parseColonType(reduceArgTypes.emplace_back()))
+ return mlir::failure();
+ return mlir::success();
+ })))
+ return mlir::failure();
+ // Resolve input operands.
+ for (auto operand_type : llvm::zip(reduceOperands, reduceArgTypes))
+ if (parser.resolveOperand(std::get<0>(operand_type),
+ std::get<1>(operand_type), result.operands))
+ return mlir::failure();
+ llvm::SmallVector<mlir::Attribute> arrayAttr(attributes.begin(),
+ attributes.end());
+ result.addAttribute(getReduceAttrsAttrName(result.name),
+ builder.getArrayAttr(arrayAttr));
+ }
+
// Parse the optional initial iteration arguments.
llvm::SmallVector<mlir::OpAsmParser::Argument> regionArgs;
- llvm::SmallVector<mlir::OpAsmParser::UnresolvedOperand> operands;
+ llvm::SmallVector<mlir::OpAsmParser::UnresolvedOperand> iterOperands;
llvm::SmallVector<mlir::Type> argTypes;
bool prependCount = false;
regionArgs.push_back(inductionVariable);
if (succeeded(parser.parseOptionalKeyword("iter_args"))) {
// Parse assignment list and results type list.
- if (parser.parseAssignmentList(regionArgs, operands) ||
+ if (parser.parseAssignmentList(regionArgs, iterOperands) ||
parser.parseArrowTypeList(result.types))
return mlir::failure();
- if (result.types.size() == operands.size() + 1)
+ if (result.types.size() == iterOperands.size() + 1)
prependCount = true;
// Resolve input operands.
llvm::ArrayRef<mlir::Type> resTypes = result.types;
- for (auto operand_type :
- llvm::zip(operands, prependCount ? resTypes.drop_front() : resTypes))
+ for (auto operand_type : llvm::zip(
+ iterOperands, prependCount ? resTypes.drop_front() : resTypes))
if (parser.resolveOperand(std::get<0>(operand_type),
std::get<1>(operand_type), result.operands))
return mlir::failure();
@@ -2153,6 +2190,12 @@ mlir::ParseResult fir::DoLoopOp::parse(mlir::OpAsmParser &parser,
prependCount = true;
}
+ // Set the operandSegmentSizes attribute
+ result.addAttribute(getOperandSegmentSizeAttr(),
+ builder.getDenseI32ArrayAttr(
+ {1, 1, 1, static_cast<int32_t>(reduceOperands.size()),
+ static_cast<int32_t>(iterOperands.size())}));
+
if (parser.parseOptionalAttrDictWithKeyword(result.attributes))
return mlir::failure();
@@ -2229,6 +2272,10 @@ mlir::LogicalResult fir::DoLoopOp::verify() {
i++;
}
+ auto reduceAttrs = getReduceAttrsAttr();
+ if (getNumReduceOperands() != (reduceAttrs ? reduceAttrs.size() : 0))
+ return emitOpError(
+ "mismatch in number of reduction variables and reduction attributes");
return mlir::success();
}
@@ -2238,6 +2285,17 @@ void fir::DoLoopOp::print(mlir::OpAsmPrinter &p) {
<< getUpperBound() << " step " << getStep();
if (getUnordered())
p << " unordered";
+ if (hasReduceOperands()) {
+ p << " reduce(";
+ auto attrs = getReduceAttrsAttr();
+ auto operands = getReduceOperands();
+ llvm::interleaveComma(llvm::zip(attrs, operands), p, [&](auto it) {
+ p << std::get<0>(it) << " -> " << std::get<1>(it) << " : "
+ << std::get<1>(it).getType();
+ });
+ p << ')';
+ printBlockTerminators = true;
+ }
if (hasIterOperands()) {
p << " iter_args(";
auto regionArgs = getRegionIterArgs();
@@ -2251,8 +2309,9 @@ void fir::DoLoopOp::print(mlir::OpAsmPrinter &p) {
p << " -> " << getResultTypes();
printBlockTerminators = true;
}
- p.printOptionalAttrDictWithKeyword((*this)->getAttrs(),
- {"unordered", "finalValue"});
+ p.printOptionalAttrDictWithKeyword(
+ (*this)->getAttrs(),
+ {"unordered", "finalValue", "reduceAttrs", "operandSegmentSizes"});
p << ' ';
p.printRegion(getRegion(), /*printEntryBlockArgs=*/false,
printBlockTerminators);
``````````
</details>
https://github.com/llvm/llvm-project/pull/93934
More information about the flang-commits
mailing list