[flang-commits] [flang] [flang] Add reduction semantics to fir.do_loop (PR #93934)
via flang-commits
flang-commits at lists.llvm.org
Wed Jun 5 10:48:14 PDT 2024
https://github.com/khaki3 updated https://github.com/llvm/llvm-project/pull/93934
>From 943119436397e5554eadf64688ad5a01205d0567 Mon Sep 17 00:00:00 2001
From: Kazuaki Matsumura <kmatsumura at nvidia.com>
Date: Thu, 30 May 2024 15:48:04 -0700
Subject: [PATCH 1/3] [flang] Add reduction semantics to fir.do_loop
---
.../flang/Optimizer/Dialect/FIRAttr.td | 30 ++++++++
.../include/flang/Optimizer/Dialect/FIROps.td | 64 ++++++++++++++--
flang/lib/Optimizer/Dialect/FIRAttr.cpp | 4 +-
flang/lib/Optimizer/Dialect/FIROps.cpp | 73 +++++++++++++++++--
4 files changed, 154 insertions(+), 17 deletions(-)
diff --git a/flang/include/flang/Optimizer/Dialect/FIRAttr.td b/flang/include/flang/Optimizer/Dialect/FIRAttr.td
index 0c34b640a5c9c..aedb6769186e9 100644
--- a/flang/include/flang/Optimizer/Dialect/FIRAttr.td
+++ b/flang/include/flang/Optimizer/Dialect/FIRAttr.td
@@ -67,6 +67,36 @@ def fir_BoxFieldAttr : I32EnumAttr<
let cppNamespace = "fir";
}
+def fir_ReduceOperationEnum : I32BitEnumAttr<"ReduceOperationEnum",
+ "intrinsic operations and functions supported by DO CONCURRENT REDUCE",
+ [
+ I32BitEnumAttrCaseBit<"Add", 0, "add">,
+ I32BitEnumAttrCaseBit<"Multiply", 1, "multiply">,
+ I32BitEnumAttrCaseBit<"AND", 2, "and">,
+ I32BitEnumAttrCaseBit<"OR", 3, "or">,
+ I32BitEnumAttrCaseBit<"EQV", 4, "eqv">,
+ I32BitEnumAttrCaseBit<"NEQV", 5, "neqv">,
+ I32BitEnumAttrCaseBit<"MAX", 6, "max">,
+ I32BitEnumAttrCaseBit<"MIN", 7, "min">,
+ I32BitEnumAttrCaseBit<"IAND", 8, "iand">,
+ I32BitEnumAttrCaseBit<"IOR", 9, "ior">,
+ I32BitEnumAttrCaseBit<"IEOR", 10, "ieor"> // Fortran intrinsic is IEOR
+ ]> {
+ let separator = ", ";
+ let cppNamespace = "::fir";
+ let printBitEnumPrimaryGroups = 1;
+}
+
+def fir_ReduceAttr : fir_Attr<"Reduce"> {
+ let mnemonic = "reduce_attr";
+
+ let parameters = (ins
+ "ReduceOperationEnum":$reduce_operation
+ );
+
+ let assemblyFormat = "`<` $reduce_operation `>`";
+}
+
// mlir::SideEffects::Resource for modelling operations which add debugging information
def DebuggingResource : Resource<"::fir::DebuggingResource">;
diff --git a/flang/include/flang/Optimizer/Dialect/FIROps.td b/flang/include/flang/Optimizer/Dialect/FIROps.td
index 3afc97475db11..d79f2da916d05 100644
--- a/flang/include/flang/Optimizer/Dialect/FIROps.td
+++ b/flang/include/flang/Optimizer/Dialect/FIROps.td
@@ -2107,8 +2107,37 @@ class region_Op<string mnemonic, list<Trait> traits = []> :
let hasVerifier = 1;
}
-def fir_DoLoopOp : region_Op<"do_loop",
- [DeclareOpInterfaceMethods<LoopLikeOpInterface,
+def fir_ReduceOp : fir_SimpleOp<"reduce", [NoMemoryEffect]> {
+ let summary = "Represent reduction semantics for the reduce clause";
+
+ let description = [{
+ Given the address of a variable, creates reduction information for the
+ reduce clause.
+
+ ```
+ %17 = fir.reduce %8 {name = "sum"} : (!fir.ref<f32>) -> !fir.ref<f32>
+ fir.do_loop ... unordered reduce(#fir.reduce_attr<add> -> %17 : !fir.ref<f32>) ...
+ ```
+
+ This operation is typically used for DO CONCURRENT REDUCE clause. The memref
+ operand may have a unique name while the `name` attribute preserves the
+ original name of a reduction variable.
+ }];
+
+ let arguments = (ins
+ AnyRefOrBoxLike:$memref,
+ Builtin_StringAttr:$name
+ );
+
+ let results = (outs AnyRefOrBox);
+
+ let assemblyFormat = [{
+ operands attr-dict `:` functional-type(operands, results)
+ }];
+}
+
+def fir_DoLoopOp : region_Op<"do_loop", [AttrSizedOperandSegments,
+ DeclareOpInterfaceMethods<LoopLikeOpInterface,
["getYieldedValuesMutable"]>]> {
let summary = "generalized loop operation";
let description = [{
@@ -2138,9 +2167,11 @@ def fir_DoLoopOp : region_Op<"do_loop",
Index:$lowerBound,
Index:$upperBound,
Index:$step,
+ Variadic<AnyType>:$reduceOperands,
Variadic<AnyType>:$initArgs,
OptionalAttr<UnitAttr>:$unordered,
- OptionalAttr<UnitAttr>:$finalValue
+ OptionalAttr<UnitAttr>:$finalValue,
+ OptionalAttr<ArrayAttr>:$reduceAttrs
);
let results = (outs Variadic<AnyType>:$results);
let regions = (region SizedRegion<1>:$region);
@@ -2151,6 +2182,8 @@ def fir_DoLoopOp : region_Op<"do_loop",
"mlir::Value":$step, CArg<"bool", "false">:$unordered,
CArg<"bool", "false">:$finalCountValue,
CArg<"mlir::ValueRange", "std::nullopt">:$iterArgs,
+ CArg<"mlir::ValueRange", "std::nullopt">:$reduceOperands,
+ CArg<"llvm::ArrayRef<mlir::Attribute>", "{}">:$reduceAttrs,
CArg<"llvm::ArrayRef<mlir::NamedAttribute>", "{}">:$attributes)>
];
@@ -2163,11 +2196,12 @@ def fir_DoLoopOp : region_Op<"do_loop",
return getBody()->getArguments().drop_front();
}
mlir::Operation::operand_range getIterOperands() {
- return getOperands().drop_front(getNumControlOperands());
+ return getOperands()
+ .drop_front(getNumControlOperands() + getNumReduceOperands());
}
llvm::MutableArrayRef<mlir::OpOperand> getInitsMutable() {
- return
- getOperation()->getOpOperands().drop_front(getNumControlOperands());
+ return getOperation()->getOpOperands()
+ .drop_front(getNumControlOperands() + getNumReduceOperands());
}
void setLowerBound(mlir::Value bound) { (*this)->setOperand(0, bound); }
@@ -2182,11 +2216,25 @@ def fir_DoLoopOp : region_Op<"do_loop",
unsigned getNumControlOperands() { return 3; }
/// Does the operation hold operands for loop-carried values
bool hasIterOperands() {
- return (*this)->getNumOperands() > getNumControlOperands();
+ return getNumIterOperands() > 0;
+ }
+ /// Does the operation hold operands for reduction variables
+ bool hasReduceOperands() {
+ return getNumReduceOperands() > 0;
+ }
+ /// Get the number of operands in operand segment `idx`
+ unsigned getNumOperands(unsigned idx) {
+ auto segments = (*this)->getAttrOfType<mlir::DenseI32ArrayAttr>(
+ getOperandSegmentSizeAttr());
+ return static_cast<unsigned>(segments[idx]);
+ }
+ /// Get Number of reduction operands (operand segment 3)
+ unsigned getNumReduceOperands() {
+ return getNumOperands(3);
+ }
/// Get Number of loop-carried values
unsigned getNumIterOperands() {
- return (*this)->getNumOperands() - getNumControlOperands();
+ return getNumOperands(4);
}
/// Get the body of the loop
diff --git a/flang/lib/Optimizer/Dialect/FIRAttr.cpp b/flang/lib/Optimizer/Dialect/FIRAttr.cpp
index 2faba63dfba07..a0202a0159228 100644
--- a/flang/lib/Optimizer/Dialect/FIRAttr.cpp
+++ b/flang/lib/Optimizer/Dialect/FIRAttr.cpp
@@ -297,6 +297,6 @@ void fir::printFirAttribute(FIROpsDialect *dialect, mlir::Attribute attr,
void FIROpsDialect::registerAttributes() {
addAttributes<ClosedIntervalAttr, ExactTypeAttr, FortranVariableFlagsAttr,
- LowerBoundAttr, PointIntervalAttr, RealAttr, SubclassAttr,
- UpperBoundAttr>();
+ LowerBoundAttr, PointIntervalAttr, RealAttr, ReduceAttr,
+ SubclassAttr, UpperBoundAttr>();
}
diff --git a/flang/lib/Optimizer/Dialect/FIROps.cpp b/flang/lib/Optimizer/Dialect/FIROps.cpp
index b541b7cdc7a5b..807459c8ec3c7 100644
--- a/flang/lib/Optimizer/Dialect/FIROps.cpp
+++ b/flang/lib/Optimizer/Dialect/FIROps.cpp
@@ -2079,9 +2079,16 @@ void fir::DoLoopOp::build(mlir::OpBuilder &builder,
mlir::OperationState &result, mlir::Value lb,
mlir::Value ub, mlir::Value step, bool unordered,
bool finalCountValue, mlir::ValueRange iterArgs,
+ mlir::ValueRange reduceOperands,
+ llvm::ArrayRef<mlir::Attribute> reduceAttrs,
llvm::ArrayRef<mlir::NamedAttribute> attributes) {
result.addOperands({lb, ub, step});
+ result.addOperands(reduceOperands);
result.addOperands(iterArgs);
+ result.addAttribute(getOperandSegmentSizeAttr(),
+ builder.getDenseI32ArrayAttr(
+ {1, 1, 1, static_cast<int32_t>(reduceOperands.size()),
+ static_cast<int32_t>(iterArgs.size())}));
if (finalCountValue) {
result.addTypes(builder.getIndexType());
result.addAttribute(getFinalValueAttrName(result.name),
@@ -2100,6 +2107,9 @@ void fir::DoLoopOp::build(mlir::OpBuilder &builder,
if (unordered)
result.addAttribute(getUnorderedAttrName(result.name),
builder.getUnitAttr());
+ if (!reduceAttrs.empty())
+ result.addAttribute(getReduceAttrsAttrName(result.name),
+ builder.getArrayAttr(reduceAttrs));
result.addAttributes(attributes);
}
@@ -2125,24 +2135,51 @@ mlir::ParseResult fir::DoLoopOp::parse(mlir::OpAsmParser &parser,
if (mlir::succeeded(parser.parseOptionalKeyword("unordered")))
result.addAttribute("unordered", builder.getUnitAttr());
+ // Parse the reduction arguments.
+ llvm::SmallVector<mlir::OpAsmParser::UnresolvedOperand> reduceOperands;
+ llvm::SmallVector<mlir::Type> reduceArgTypes;
+ if (succeeded(parser.parseOptionalKeyword("reduce"))) {
+ // Parse reduction attributes and variables.
+ llvm::SmallVector<ReduceAttr> attributes;
+ if (failed(parser.parseCommaSeparatedList(
+ mlir::AsmParser::Delimiter::Paren, [&]() {
+ if (parser.parseAttribute(attributes.emplace_back()) ||
+ parser.parseArrow() ||
+ parser.parseOperand(reduceOperands.emplace_back()) ||
+ parser.parseColonType(reduceArgTypes.emplace_back()))
+ return mlir::failure();
+ return mlir::success();
+ })))
+ return mlir::failure();
+ // Resolve input operands.
+ for (auto operand_type : llvm::zip(reduceOperands, reduceArgTypes))
+ if (parser.resolveOperand(std::get<0>(operand_type),
+ std::get<1>(operand_type), result.operands))
+ return mlir::failure();
+ llvm::SmallVector<mlir::Attribute> arrayAttr(attributes.begin(),
+ attributes.end());
+ result.addAttribute(getReduceAttrsAttrName(result.name),
+ builder.getArrayAttr(arrayAttr));
+ }
+
// Parse the optional initial iteration arguments.
llvm::SmallVector<mlir::OpAsmParser::Argument> regionArgs;
- llvm::SmallVector<mlir::OpAsmParser::UnresolvedOperand> operands;
+ llvm::SmallVector<mlir::OpAsmParser::UnresolvedOperand> iterOperands;
llvm::SmallVector<mlir::Type> argTypes;
bool prependCount = false;
regionArgs.push_back(inductionVariable);
if (succeeded(parser.parseOptionalKeyword("iter_args"))) {
// Parse assignment list and results type list.
- if (parser.parseAssignmentList(regionArgs, operands) ||
+ if (parser.parseAssignmentList(regionArgs, iterOperands) ||
parser.parseArrowTypeList(result.types))
return mlir::failure();
- if (result.types.size() == operands.size() + 1)
+ if (result.types.size() == iterOperands.size() + 1)
prependCount = true;
// Resolve input operands.
llvm::ArrayRef<mlir::Type> resTypes = result.types;
- for (auto operand_type :
- llvm::zip(operands, prependCount ? resTypes.drop_front() : resTypes))
+ for (auto operand_type : llvm::zip(
+ iterOperands, prependCount ? resTypes.drop_front() : resTypes))
if (parser.resolveOperand(std::get<0>(operand_type),
std::get<1>(operand_type), result.operands))
return mlir::failure();
@@ -2153,6 +2190,12 @@ mlir::ParseResult fir::DoLoopOp::parse(mlir::OpAsmParser &parser,
prependCount = true;
}
+ // Set the operandSegmentSizes attribute
+ result.addAttribute(getOperandSegmentSizeAttr(),
+ builder.getDenseI32ArrayAttr(
+ {1, 1, 1, static_cast<int32_t>(reduceOperands.size()),
+ static_cast<int32_t>(iterOperands.size())}));
+
if (parser.parseOptionalAttrDictWithKeyword(result.attributes))
return mlir::failure();
@@ -2229,6 +2272,10 @@ mlir::LogicalResult fir::DoLoopOp::verify() {
i++;
}
+ auto reduceAttrs = getReduceAttrsAttr();
+ if (getNumReduceOperands() != (reduceAttrs ? reduceAttrs.size() : 0))
+ return emitOpError(
+ "mismatch in number of reduction variables and reduction attributes");
return mlir::success();
}
@@ -2238,6 +2285,17 @@ void fir::DoLoopOp::print(mlir::OpAsmPrinter &p) {
<< getUpperBound() << " step " << getStep();
if (getUnordered())
p << " unordered";
+ if (hasReduceOperands()) {
+ p << " reduce(";
+ auto attrs = getReduceAttrsAttr();
+ auto operands = getReduceOperands();
+ llvm::interleaveComma(llvm::zip(attrs, operands), p, [&](auto it) {
+ p << std::get<0>(it) << " -> " << std::get<1>(it) << " : "
+ << std::get<1>(it).getType();
+ });
+ p << ')';
+ printBlockTerminators = true;
+ }
if (hasIterOperands()) {
p << " iter_args(";
auto regionArgs = getRegionIterArgs();
@@ -2251,8 +2309,9 @@ void fir::DoLoopOp::print(mlir::OpAsmPrinter &p) {
p << " -> " << getResultTypes();
printBlockTerminators = true;
}
- p.printOptionalAttrDictWithKeyword((*this)->getAttrs(),
- {"unordered", "finalValue"});
+ p.printOptionalAttrDictWithKeyword(
+ (*this)->getAttrs(),
+ {"unordered", "finalValue", "reduceAttrs", "operandSegmentSizes"});
p << ' ';
p.printRegion(getRegion(), /*printEntryBlockArgs=*/false,
printBlockTerminators);
>From 74c06ae6f302813ef9a128b05ddbf70912d7e0b8 Mon Sep 17 00:00:00 2001
From: Kazuaki Matsumura <kmatsumura at nvidia.com>
Date: Mon, 3 Jun 2024 08:36:42 -0700
Subject: [PATCH 2/3] [flang] Add test/Fir/loop03.fir to test the reduction
semantics of fir.do_loop
---
flang/test/Fir/loop03.fir | 19 +++++++++++++++++++
1 file changed, 19 insertions(+)
create mode 100644 flang/test/Fir/loop03.fir
diff --git a/flang/test/Fir/loop03.fir b/flang/test/Fir/loop03.fir
new file mode 100644
index 0000000000000..916ccaeaa2aef
--- /dev/null
+++ b/flang/test/Fir/loop03.fir
@@ -0,0 +1,19 @@
+// Test the reduction semantics of fir.do_loop
+// RUN: fir-opt %s | FileCheck %s
+
+func.func @reduction() {
+ %bound = arith.constant 10 : index
+ %step = arith.constant 1 : index
+ %sum = fir.alloca i32
+ %red = fir.reduce %sum {name = "sum"} : (!fir.ref<i32>) -> !fir.ref<i32>
+// CHECK: %[[VAL_0:.*]] = fir.alloca i32
+// CHECK: %[[VAL_1:.*]] = fir.reduce %[[VAL_0]] {name = "sum"} : (!fir.ref<i32>) -> !fir.ref<i32>
+// CHECK: fir.do_loop %[[VAL_2:.*]] = %[[VAL_3:.*]] to %[[VAL_4:.*]] step %[[VAL_5:.*]] unordered reduce(#fir.reduce_attr<add> -> %[[VAL_1]] : !fir.ref<i32>) {
+ fir.do_loop %iv = %step to %bound step %step unordered reduce(#fir.reduce_attr<add> -> %red : !fir.ref<i32>) {
+ %index = fir.convert %iv : (index) -> i32
+ %1 = fir.load %sum : !fir.ref<i32>
+ %2 = arith.addi %index, %1 : i32
+ fir.store %2 to %sum : !fir.ref<i32>
+ }
+ return
+}
>From 0b655d1cd476efb065e83ea15ce6821a4b49132a Mon Sep 17 00:00:00 2001
From: Kazuaki Matsumura <kmatsumura at nvidia.com>
Date: Mon, 3 Jun 2024 09:10:32 -0700
Subject: [PATCH 3/3] [flang] Remove fir.reduce
---
.../include/flang/Optimizer/Dialect/FIROps.td | 29 -------------------
flang/test/Fir/loop03.fir | 6 ++--
2 files changed, 2 insertions(+), 33 deletions(-)
diff --git a/flang/include/flang/Optimizer/Dialect/FIROps.td b/flang/include/flang/Optimizer/Dialect/FIROps.td
index d79f2da916d05..0a7bd4178517a 100644
--- a/flang/include/flang/Optimizer/Dialect/FIROps.td
+++ b/flang/include/flang/Optimizer/Dialect/FIROps.td
@@ -2107,35 +2107,6 @@ class region_Op<string mnemonic, list<Trait> traits = []> :
let hasVerifier = 1;
}
-def fir_ReduceOp : fir_SimpleOp<"reduce", [NoMemoryEffect]> {
- let summary = "Represent reduction semantics for the reduce clause";
-
- let description = [{
- Given the address of a variable, creates reduction information for the
- reduce clause.
-
- ```
- %17 = fir.reduce %8 {name = "sum"} : (!fir.ref<f32>) -> !fir.ref<f32>
- fir.do_loop ... unordered reduce(#fir.reduce_attr<add> -> %17 : !fir.ref<f32>) ...
- ```
-
- This operation is typically used for DO CONCURRENT REDUCE clause. The memref
- operand may have a unique name while the `name` attribute preserves the
- original name of a reduction variable.
- }];
-
- let arguments = (ins
- AnyRefOrBoxLike:$memref,
- Builtin_StringAttr:$name
- );
-
- let results = (outs AnyRefOrBox);
-
- let assemblyFormat = [{
- operands attr-dict `:` functional-type(operands, results)
- }];
-}
-
def fir_DoLoopOp : region_Op<"do_loop", [AttrSizedOperandSegments,
DeclareOpInterfaceMethods<LoopLikeOpInterface,
["getYieldedValuesMutable"]>]> {
diff --git a/flang/test/Fir/loop03.fir b/flang/test/Fir/loop03.fir
index 916ccaeaa2aef..b88dcaf8639be 100644
--- a/flang/test/Fir/loop03.fir
+++ b/flang/test/Fir/loop03.fir
@@ -5,11 +5,9 @@ func.func @reduction() {
%bound = arith.constant 10 : index
%step = arith.constant 1 : index
%sum = fir.alloca i32
- %red = fir.reduce %sum {name = "sum"} : (!fir.ref<i32>) -> !fir.ref<i32>
// CHECK: %[[VAL_0:.*]] = fir.alloca i32
-// CHECK: %[[VAL_1:.*]] = fir.reduce %[[VAL_0]] {name = "sum"} : (!fir.ref<i32>) -> !fir.ref<i32>
-// CHECK: fir.do_loop %[[VAL_2:.*]] = %[[VAL_3:.*]] to %[[VAL_4:.*]] step %[[VAL_5:.*]] unordered reduce(#fir.reduce_attr<add> -> %[[VAL_1]] : !fir.ref<i32>) {
- fir.do_loop %iv = %step to %bound step %step unordered reduce(#fir.reduce_attr<add> -> %red : !fir.ref<i32>) {
+// CHECK: fir.do_loop %[[VAL_1:.*]] = %[[VAL_2:.*]] to %[[VAL_3:.*]] step %[[VAL_4:.*]] unordered reduce(#fir.reduce_attr<add> -> %[[VAL_0]] : !fir.ref<i32>) {
+ fir.do_loop %iv = %step to %bound step %step unordered reduce(#fir.reduce_attr<add> -> %sum : !fir.ref<i32>) {
%index = fir.convert %iv : (index) -> i32
%1 = fir.load %sum : !fir.ref<i32>
%2 = arith.addi %index, %1 : i32
More information about the flang-commits
mailing list