[flang-commits] [flang] 88cdd99 - [flang] Add reduction semantics to fir.do_loop (#93934)
via flang-commits
flang-commits at lists.llvm.org
Thu Jun 6 11:16:44 PDT 2024
Author: khaki3
Date: 2024-06-06T11:16:40-07:00
New Revision: 88cdd9905597ace5b1ac7d080df5326d3399b3f8
URL: https://github.com/llvm/llvm-project/commit/88cdd9905597ace5b1ac7d080df5326d3399b3f8
DIFF: https://github.com/llvm/llvm-project/commit/88cdd9905597ace5b1ac7d080df5326d3399b3f8.diff
LOG: [flang] Add reduction semantics to fir.do_loop (#93934)
Derived from #92480. This PR introduces reduction semantics into loops
for DO CONCURRENT REDUCE. The `fir.do_loop` operation now implicitly
carries the `operandSegmentSizes` attribute (it is elided from the printed
form) and takes a variable-length list of reduction operands, whose
operations are given as `fir.reduce_attr` attributes. For the sake of
compatibility, the new arguments of `fir.do_loop`'s builder have default
values and are placed after the existing `iterArgs` argument. In the
assembly format, `iter_args` must directly precede the declaration of
result types, so the new clause for reduction variables (`reduce`) is put
before it, in the middle of the arguments.
Added:
flang/test/Fir/loop03.fir
Modified:
flang/include/flang/Optimizer/Dialect/FIRAttr.td
flang/include/flang/Optimizer/Dialect/FIROps.td
flang/lib/Optimizer/Dialect/FIRAttr.cpp
flang/lib/Optimizer/Dialect/FIROps.cpp
Removed:
################################################################################
diff --git a/flang/include/flang/Optimizer/Dialect/FIRAttr.td b/flang/include/flang/Optimizer/Dialect/FIRAttr.td
index 0c34b640a5c9c..aedb6769186e9 100644
--- a/flang/include/flang/Optimizer/Dialect/FIRAttr.td
+++ b/flang/include/flang/Optimizer/Dialect/FIRAttr.td
@@ -67,6 +67,36 @@ def fir_BoxFieldAttr : I32EnumAttr<
let cppNamespace = "fir";
}
+def fir_ReduceOperationEnum : I32BitEnumAttr<"ReduceOperationEnum",
+ "intrinsic operations and functions supported by DO CONCURRENT REDUCE",
+ [
+ I32BitEnumAttrCaseBit<"Add", 0, "add">,
+ I32BitEnumAttrCaseBit<"Multiply", 1, "multiply">,
+ I32BitEnumAttrCaseBit<"AND", 2, "and">,
+ I32BitEnumAttrCaseBit<"OR", 3, "or">,
+ I32BitEnumAttrCaseBit<"EQV", 4, "eqv">,
+ I32BitEnumAttrCaseBit<"NEQV", 5, "neqv">,
+ I32BitEnumAttrCaseBit<"MAX", 6, "max">,
+ I32BitEnumAttrCaseBit<"MIN", 7, "min">,
+ I32BitEnumAttrCaseBit<"IAND", 8, "iand">,
+ I32BitEnumAttrCaseBit<"IOR", 9, "ior">,
+ I32BitEnumAttrCaseBit<"EIOR", 10, "eior">
+ ]> {
+ let separator = ", ";
+ let cppNamespace = "::fir";
+ let printBitEnumPrimaryGroups = 1;
+}
+
+def fir_ReduceAttr : fir_Attr<"Reduce"> {
+ let mnemonic = "reduce_attr";
+
+ let parameters = (ins
+ "ReduceOperationEnum":$reduce_operation
+ );
+
+ let assemblyFormat = "`<` $reduce_operation `>`";
+}
+
// mlir::SideEffects::Resource for modelling operations which add debugging information
def DebuggingResource : Resource<"::fir::DebuggingResource">;
diff --git a/flang/include/flang/Optimizer/Dialect/FIROps.td b/flang/include/flang/Optimizer/Dialect/FIROps.td
index 37fbd1f9692a4..e7da3af5485cc 100644
--- a/flang/include/flang/Optimizer/Dialect/FIROps.td
+++ b/flang/include/flang/Optimizer/Dialect/FIROps.td
@@ -2125,8 +2125,8 @@ class region_Op<string mnemonic, list<Trait> traits = []> :
let hasVerifier = 1;
}
-def fir_DoLoopOp : region_Op<"do_loop",
- [DeclareOpInterfaceMethods<LoopLikeOpInterface,
+def fir_DoLoopOp : region_Op<"do_loop", [AttrSizedOperandSegments,
+ DeclareOpInterfaceMethods<LoopLikeOpInterface,
["getYieldedValuesMutable"]>]> {
let summary = "generalized loop operation";
let description = [{
@@ -2156,9 +2156,11 @@ def fir_DoLoopOp : region_Op<"do_loop",
Index:$lowerBound,
Index:$upperBound,
Index:$step,
+ Variadic<AnyType>:$reduceOperands,
Variadic<AnyType>:$initArgs,
OptionalAttr<UnitAttr>:$unordered,
- OptionalAttr<UnitAttr>:$finalValue
+ OptionalAttr<UnitAttr>:$finalValue,
+ OptionalAttr<ArrayAttr>:$reduceAttrs
);
let results = (outs Variadic<AnyType>:$results);
let regions = (region SizedRegion<1>:$region);
@@ -2169,6 +2171,8 @@ def fir_DoLoopOp : region_Op<"do_loop",
"mlir::Value":$step, CArg<"bool", "false">:$unordered,
CArg<"bool", "false">:$finalCountValue,
CArg<"mlir::ValueRange", "std::nullopt">:$iterArgs,
+ CArg<"mlir::ValueRange", "std::nullopt">:$reduceOperands,
+ CArg<"llvm::ArrayRef<mlir::Attribute>", "{}">:$reduceAttrs,
CArg<"llvm::ArrayRef<mlir::NamedAttribute>", "{}">:$attributes)>
];
@@ -2181,11 +2185,12 @@ def fir_DoLoopOp : region_Op<"do_loop",
return getBody()->getArguments().drop_front();
}
mlir::Operation::operand_range getIterOperands() {
- return getOperands().drop_front(getNumControlOperands());
+ return getOperands()
+ .drop_front(getNumControlOperands() + getNumReduceOperands());
}
llvm::MutableArrayRef<mlir::OpOperand> getInitsMutable() {
- return
- getOperation()->getOpOperands().drop_front(getNumControlOperands());
+ return getOperation()->getOpOperands()
+ .drop_front(getNumControlOperands() + getNumReduceOperands());
}
void setLowerBound(mlir::Value bound) { (*this)->setOperand(0, bound); }
@@ -2200,11 +2205,25 @@ def fir_DoLoopOp : region_Op<"do_loop",
unsigned getNumControlOperands() { return 3; }
/// Does the operation hold operands for loop-carried values
bool hasIterOperands() {
- return (*this)->getNumOperands() > getNumControlOperands();
+ return getNumIterOperands() > 0;
+ }
+ /// Does the operation hold operands for reduction variables
+ bool hasReduceOperands() {
+ return getNumReduceOperands() > 0;
+ }
+ /// Get Number of variadic operands
+ unsigned getNumOperands(unsigned idx) {
+ auto segments = (*this)->getAttrOfType<mlir::DenseI32ArrayAttr>(
+ getOperandSegmentSizeAttr());
+ return static_cast<unsigned>(segments[idx]);
+ }
+ // Get Number of reduction operands
+ unsigned getNumReduceOperands() {
+ return getNumOperands(3);
}
/// Get Number of loop-carried values
unsigned getNumIterOperands() {
- return (*this)->getNumOperands() - getNumControlOperands();
+ return getNumOperands(4);
}
/// Get the body of the loop
diff --git a/flang/lib/Optimizer/Dialect/FIRAttr.cpp b/flang/lib/Optimizer/Dialect/FIRAttr.cpp
index 2faba63dfba07..a0202a0159228 100644
--- a/flang/lib/Optimizer/Dialect/FIRAttr.cpp
+++ b/flang/lib/Optimizer/Dialect/FIRAttr.cpp
@@ -297,6 +297,6 @@ void fir::printFirAttribute(FIROpsDialect *dialect, mlir::Attribute attr,
void FIROpsDialect::registerAttributes() {
addAttributes<ClosedIntervalAttr, ExactTypeAttr, FortranVariableFlagsAttr,
- LowerBoundAttr, PointIntervalAttr, RealAttr, SubclassAttr,
- UpperBoundAttr>();
+ LowerBoundAttr, PointIntervalAttr, RealAttr, ReduceAttr,
+ SubclassAttr, UpperBoundAttr>();
}
diff --git a/flang/lib/Optimizer/Dialect/FIROps.cpp b/flang/lib/Optimizer/Dialect/FIROps.cpp
index b530a9dc1bcc4..75ca738211abe 100644
--- a/flang/lib/Optimizer/Dialect/FIROps.cpp
+++ b/flang/lib/Optimizer/Dialect/FIROps.cpp
@@ -2456,9 +2456,16 @@ void fir::DoLoopOp::build(mlir::OpBuilder &builder,
mlir::OperationState &result, mlir::Value lb,
mlir::Value ub, mlir::Value step, bool unordered,
bool finalCountValue, mlir::ValueRange iterArgs,
+ mlir::ValueRange reduceOperands,
+ llvm::ArrayRef<mlir::Attribute> reduceAttrs,
llvm::ArrayRef<mlir::NamedAttribute> attributes) {
result.addOperands({lb, ub, step});
+ result.addOperands(reduceOperands);
result.addOperands(iterArgs);
+ result.addAttribute(getOperandSegmentSizeAttr(),
+ builder.getDenseI32ArrayAttr(
+ {1, 1, 1, static_cast<int32_t>(reduceOperands.size()),
+ static_cast<int32_t>(iterArgs.size())}));
if (finalCountValue) {
result.addTypes(builder.getIndexType());
result.addAttribute(getFinalValueAttrName(result.name),
@@ -2477,6 +2484,9 @@ void fir::DoLoopOp::build(mlir::OpBuilder &builder,
if (unordered)
result.addAttribute(getUnorderedAttrName(result.name),
builder.getUnitAttr());
+ if (!reduceAttrs.empty())
+ result.addAttribute(getReduceAttrsAttrName(result.name),
+ builder.getArrayAttr(reduceAttrs));
result.addAttributes(attributes);
}
@@ -2502,24 +2512,51 @@ mlir::ParseResult fir::DoLoopOp::parse(mlir::OpAsmParser &parser,
if (mlir::succeeded(parser.parseOptionalKeyword("unordered")))
result.addAttribute("unordered", builder.getUnitAttr());
+ // Parse the reduction arguments.
+ llvm::SmallVector<mlir::OpAsmParser::UnresolvedOperand> reduceOperands;
+ llvm::SmallVector<mlir::Type> reduceArgTypes;
+ if (succeeded(parser.parseOptionalKeyword("reduce"))) {
+ // Parse reduction attributes and variables.
+ llvm::SmallVector<ReduceAttr> attributes;
+ if (failed(parser.parseCommaSeparatedList(
+ mlir::AsmParser::Delimiter::Paren, [&]() {
+ if (parser.parseAttribute(attributes.emplace_back()) ||
+ parser.parseArrow() ||
+ parser.parseOperand(reduceOperands.emplace_back()) ||
+ parser.parseColonType(reduceArgTypes.emplace_back()))
+ return mlir::failure();
+ return mlir::success();
+ })))
+ return mlir::failure();
+ // Resolve input operands.
+ for (auto operand_type : llvm::zip(reduceOperands, reduceArgTypes))
+ if (parser.resolveOperand(std::get<0>(operand_type),
+ std::get<1>(operand_type), result.operands))
+ return mlir::failure();
+ llvm::SmallVector<mlir::Attribute> arrayAttr(attributes.begin(),
+ attributes.end());
+ result.addAttribute(getReduceAttrsAttrName(result.name),
+ builder.getArrayAttr(arrayAttr));
+ }
+
// Parse the optional initial iteration arguments.
llvm::SmallVector<mlir::OpAsmParser::Argument> regionArgs;
- llvm::SmallVector<mlir::OpAsmParser::UnresolvedOperand> operands;
+ llvm::SmallVector<mlir::OpAsmParser::UnresolvedOperand> iterOperands;
llvm::SmallVector<mlir::Type> argTypes;
bool prependCount = false;
regionArgs.push_back(inductionVariable);
if (succeeded(parser.parseOptionalKeyword("iter_args"))) {
// Parse assignment list and results type list.
- if (parser.parseAssignmentList(regionArgs, operands) ||
+ if (parser.parseAssignmentList(regionArgs, iterOperands) ||
parser.parseArrowTypeList(result.types))
return mlir::failure();
- if (result.types.size() == operands.size() + 1)
+ if (result.types.size() == iterOperands.size() + 1)
prependCount = true;
// Resolve input operands.
llvm::ArrayRef<mlir::Type> resTypes = result.types;
- for (auto operand_type :
- llvm::zip(operands, prependCount ? resTypes.drop_front() : resTypes))
+ for (auto operand_type : llvm::zip(
+ iterOperands, prependCount ? resTypes.drop_front() : resTypes))
if (parser.resolveOperand(std::get<0>(operand_type),
std::get<1>(operand_type), result.operands))
return mlir::failure();
@@ -2530,6 +2567,12 @@ mlir::ParseResult fir::DoLoopOp::parse(mlir::OpAsmParser &parser,
prependCount = true;
}
+ // Set the operandSegmentSizes attribute
+ result.addAttribute(getOperandSegmentSizeAttr(),
+ builder.getDenseI32ArrayAttr(
+ {1, 1, 1, static_cast<int32_t>(reduceOperands.size()),
+ static_cast<int32_t>(iterOperands.size())}));
+
if (parser.parseOptionalAttrDictWithKeyword(result.attributes))
return mlir::failure();
@@ -2606,6 +2649,10 @@ mlir::LogicalResult fir::DoLoopOp::verify() {
i++;
}
+ auto reduceAttrs = getReduceAttrsAttr();
+ if (getNumReduceOperands() != (reduceAttrs ? reduceAttrs.size() : 0))
+ return emitOpError(
+ "mismatch in number of reduction variables and reduction attributes");
return mlir::success();
}
@@ -2615,6 +2662,17 @@ void fir::DoLoopOp::print(mlir::OpAsmPrinter &p) {
<< getUpperBound() << " step " << getStep();
if (getUnordered())
p << " unordered";
+ if (hasReduceOperands()) {
+ p << " reduce(";
+ auto attrs = getReduceAttrsAttr();
+ auto operands = getReduceOperands();
+ llvm::interleaveComma(llvm::zip(attrs, operands), p, [&](auto it) {
+ p << std::get<0>(it) << " -> " << std::get<1>(it) << " : "
+ << std::get<1>(it).getType();
+ });
+ p << ')';
+ printBlockTerminators = true;
+ }
if (hasIterOperands()) {
p << " iter_args(";
auto regionArgs = getRegionIterArgs();
@@ -2628,8 +2686,9 @@ void fir::DoLoopOp::print(mlir::OpAsmPrinter &p) {
p << " -> " << getResultTypes();
printBlockTerminators = true;
}
- p.printOptionalAttrDictWithKeyword((*this)->getAttrs(),
- {"unordered", "finalValue"});
+ p.printOptionalAttrDictWithKeyword(
+ (*this)->getAttrs(),
+ {"unordered", "finalValue", "reduceAttrs", "operandSegmentSizes"});
p << ' ';
p.printRegion(getRegion(), /*printEntryBlockArgs=*/false,
printBlockTerminators);
diff --git a/flang/test/Fir/loop03.fir b/flang/test/Fir/loop03.fir
new file mode 100644
index 0000000000000..b88dcaf8639be
--- /dev/null
+++ b/flang/test/Fir/loop03.fir
@@ -0,0 +1,17 @@
+// Test the reduction semantics of fir.do_loop
+// RUN: fir-opt %s | FileCheck %s
+
+func.func @reduction() {
+ %bound = arith.constant 10 : index
+ %step = arith.constant 1 : index
+ %sum = fir.alloca i32
+// CHECK: %[[VAL_0:.*]] = fir.alloca i32
+// CHECK: fir.do_loop %[[VAL_1:.*]] = %[[VAL_2:.*]] to %[[VAL_3:.*]] step %[[VAL_4:.*]] unordered reduce(#fir.reduce_attr<add> -> %[[VAL_0]] : !fir.ref<i32>) {
+ fir.do_loop %iv = %step to %bound step %step unordered reduce(#fir.reduce_attr<add> -> %sum : !fir.ref<i32>) {
+ %index = fir.convert %iv : (index) -> i32
+ %1 = fir.load %sum : !fir.ref<i32>
+ %2 = arith.addi %index, %1 : i32
+ fir.store %2 to %sum : !fir.ref<i32>
+ }
+ return
+}
More information about the flang-commits
mailing list