[flang-commits] [flang] [Flang][OpenMP] NFC: Refactor reduction code (PR #79876)
via flang-commits
flang-commits at lists.llvm.org
Mon Jan 29 10:15:34 PST 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-flang-fir-hlfir
@llvm/pr-subscribers-flang-openmp
Author: Kiran Chandramohan (kiranchandramohan)
<details>
<summary>Changes</summary>
Introduces a new enumeration that lists all Fortran reduction identifiers. Moves the combiner code generation into a separate function so that it can be reused in an array context in the future.
---
Patch is 23.40 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/79876.diff
1 File Affected:
- (modified) flang/lib/Lower/OpenMP.cpp (+170-167)
``````````diff
diff --git a/flang/lib/Lower/OpenMP.cpp b/flang/lib/Lower/OpenMP.cpp
index 7dd25f75d9eb76f..52d222f3d601f6a 100644
--- a/flang/lib/Lower/OpenMP.cpp
+++ b/flang/lib/Lower/OpenMP.cpp
@@ -726,21 +726,59 @@ static void checkMapType(mlir::Location location, mlir::Type type) {
class ReductionProcessor {
public:
- enum IntrinsicProc { MAX, MIN, IAND, IOR, IEOR };
- static IntrinsicProc
+ // TODO: Move this enumeration to the OpenMP dialect
+ enum ReductionIdentifier {
+ ID,
+ USER_DEF_OP,
+ ADD,
+ SUBTRACT,
+ MULTIPLY,
+ AND,
+ OR,
+ EQV,
+ NEQV,
+ MAX,
+ MIN,
+ IAND,
+ IOR,
+ IEOR
+ };
+ static ReductionIdentifier
getReductionType(const Fortran::parser::ProcedureDesignator &pd) {
- auto redType = llvm::StringSwitch<std::optional<IntrinsicProc>>(
+ auto redType = llvm::StringSwitch<std::optional<ReductionIdentifier>>(
getRealName(pd).ToString())
- .Case("max", IntrinsicProc::MAX)
- .Case("min", IntrinsicProc::MIN)
- .Case("iand", IntrinsicProc::IAND)
- .Case("ior", IntrinsicProc::IOR)
- .Case("ieor", IntrinsicProc::IEOR)
+ .Case("max", ReductionIdentifier::MAX)
+ .Case("min", ReductionIdentifier::MIN)
+ .Case("iand", ReductionIdentifier::IAND)
+ .Case("ior", ReductionIdentifier::IOR)
+ .Case("ieor", ReductionIdentifier::IEOR)
.Default(std::nullopt);
assert(redType && "Invalid Reduction");
return *redType;
}
+ static ReductionIdentifier getReductionType(
+ Fortran::parser::DefinedOperator::IntrinsicOperator intrinsicOp) {
+ switch (intrinsicOp) {
+ case Fortran::parser::DefinedOperator::IntrinsicOperator::Add:
+ return ReductionIdentifier::ADD;
+ case Fortran::parser::DefinedOperator::IntrinsicOperator::Subtract:
+ return ReductionIdentifier::SUBTRACT;
+ case Fortran::parser::DefinedOperator::IntrinsicOperator::Multiply:
+ return ReductionIdentifier::MULTIPLY;
+ case Fortran::parser::DefinedOperator::IntrinsicOperator::AND:
+ return ReductionIdentifier::AND;
+ case Fortran::parser::DefinedOperator::IntrinsicOperator::EQV:
+ return ReductionIdentifier::EQV;
+ case Fortran::parser::DefinedOperator::IntrinsicOperator::OR:
+ return ReductionIdentifier::OR;
+ case Fortran::parser::DefinedOperator::IntrinsicOperator::NEQV:
+ return ReductionIdentifier::NEQV;
+ default:
+ llvm_unreachable("unexpected intrinsic operator in reduction");
+ }
+ }
+
static bool supportedIntrinsicProcReduction(
const Fortran::parser::ProcedureDesignator &pd) {
const auto *name{Fortran::parser::Unwrap<Fortran::parser::Name>(pd)};
@@ -748,13 +786,13 @@ class ReductionProcessor {
if (!name->symbol->GetUltimate().attrs().test(
Fortran::semantics::Attr::INTRINSIC))
return false;
- auto redType = llvm::StringSwitch<std::optional<IntrinsicProc>>(
+ auto redType = llvm::StringSwitch<std::optional<ReductionIdentifier>>(
getRealName(name).ToString())
- .Case("max", IntrinsicProc::MAX)
- .Case("min", IntrinsicProc::MIN)
- .Case("iand", IntrinsicProc::IAND)
- .Case("ior", IntrinsicProc::IOR)
- .Case("ieor", IntrinsicProc::IEOR)
+ .Case("max", ReductionIdentifier::MAX)
+ .Case("min", ReductionIdentifier::MIN)
+ .Case("iand", ReductionIdentifier::IAND)
+ .Case("ior", ReductionIdentifier::IOR)
+ .Case("ieor", ReductionIdentifier::IEOR)
.Default(std::nullopt);
if (redType)
return true;
@@ -812,32 +850,30 @@ class ReductionProcessor {
/// reductionOpName. For example:
/// 0 + x = x,
/// 1 * x = x
- static int getOperationIdentity(
- Fortran::parser::DefinedOperator::IntrinsicOperator intrinsicOp,
- mlir::Location loc) {
- switch (intrinsicOp) {
- case Fortran::parser::DefinedOperator::IntrinsicOperator::Add:
- case Fortran::parser::DefinedOperator::IntrinsicOperator::OR:
- case Fortran::parser::DefinedOperator::IntrinsicOperator::NEQV:
+ static int getOperationIdentity(ReductionIdentifier redId,
+ mlir::Location loc) {
+ switch (redId) {
+ case ReductionIdentifier::ADD:
+ case ReductionIdentifier::OR:
+ case ReductionIdentifier::NEQV:
return 0;
- case Fortran::parser::DefinedOperator::IntrinsicOperator::Multiply:
- case Fortran::parser::DefinedOperator::IntrinsicOperator::AND:
- case Fortran::parser::DefinedOperator::IntrinsicOperator::EQV:
+ case ReductionIdentifier::MULTIPLY:
+ case ReductionIdentifier::AND:
+ case ReductionIdentifier::EQV:
return 1;
default:
TODO(loc, "Reduction of some intrinsic operators is not supported");
}
}
- static mlir::Value getIntrinsicProcInitValue(
- mlir::Location loc, mlir::Type type,
- const Fortran::parser::ProcedureDesignator &procDesignator,
- fir::FirOpBuilder &builder) {
+ static mlir::Value getReductionInitValue(mlir::Location loc, mlir::Type type,
+ ReductionIdentifier redId,
+ fir::FirOpBuilder &builder) {
assert((fir::isa_integer(type) || fir::isa_real(type) ||
type.isa<fir::LogicalType>()) &&
"only integer, logical and real types are currently supported");
- switch (getReductionType(procDesignator)) {
- case IntrinsicProc::MAX: {
+ switch (redId) {
+ case ReductionIdentifier::MAX: {
if (auto ty = type.dyn_cast<mlir::FloatType>()) {
const llvm::fltSemantics &sem = ty.getFloatSemantics();
return builder.createRealConstant(
@@ -847,7 +883,7 @@ class ReductionProcessor {
int64_t minInt = llvm::APInt::getSignedMinValue(bits).getSExtValue();
return builder.createIntegerConstant(loc, type, minInt);
}
- case IntrinsicProc::MIN: {
+ case ReductionIdentifier::MIN: {
if (auto ty = type.dyn_cast<mlir::FloatType>()) {
const llvm::fltSemantics &sem = ty.getFloatSemantics();
return builder.createRealConstant(
@@ -857,46 +893,50 @@ class ReductionProcessor {
int64_t maxInt = llvm::APInt::getSignedMaxValue(bits).getSExtValue();
return builder.createIntegerConstant(loc, type, maxInt);
}
- case IntrinsicProc::IOR: {
+ case ReductionIdentifier::IOR: {
unsigned bits = type.getIntOrFloatBitWidth();
int64_t zeroInt = llvm::APInt::getZero(bits).getSExtValue();
return builder.createIntegerConstant(loc, type, zeroInt);
}
- case IntrinsicProc::IEOR: {
+ case ReductionIdentifier::IEOR: {
unsigned bits = type.getIntOrFloatBitWidth();
int64_t zeroInt = llvm::APInt::getZero(bits).getSExtValue();
return builder.createIntegerConstant(loc, type, zeroInt);
}
- case IntrinsicProc::IAND: {
+ case ReductionIdentifier::IAND: {
unsigned bits = type.getIntOrFloatBitWidth();
int64_t allOnInt = llvm::APInt::getAllOnes(bits).getSExtValue();
return builder.createIntegerConstant(loc, type, allOnInt);
}
- }
- llvm_unreachable("Unknown Reduction Intrinsic");
- }
+ case ReductionIdentifier::ADD:
+ case ReductionIdentifier::MULTIPLY:
+ case ReductionIdentifier::AND:
+ case ReductionIdentifier::OR:
+ case ReductionIdentifier::EQV:
+ case ReductionIdentifier::NEQV:
+ if (type.isa<mlir::FloatType>())
+ return builder.create<mlir::arith::ConstantOp>(
+ loc, type,
+ builder.getFloatAttr(type,
+ (double)getOperationIdentity(redId, loc)));
+
+ if (type.isa<fir::LogicalType>()) {
+ mlir::Value intConst = builder.create<mlir::arith::ConstantOp>(
+ loc, builder.getI1Type(),
+ builder.getIntegerAttr(builder.getI1Type(),
+ getOperationIdentity(redId, loc)));
+ return builder.createConvert(loc, type, intConst);
+ }
- static mlir::Value getIntrinsicOpInitValue(
- mlir::Location loc, mlir::Type type,
- Fortran::parser::DefinedOperator::IntrinsicOperator intrinsicOp,
- fir::FirOpBuilder &builder) {
- if (type.isa<mlir::FloatType>())
return builder.create<mlir::arith::ConstantOp>(
loc, type,
- builder.getFloatAttr(type,
- (double)getOperationIdentity(intrinsicOp, loc)));
-
- if (type.isa<fir::LogicalType>()) {
- mlir::Value intConst = builder.create<mlir::arith::ConstantOp>(
- loc, builder.getI1Type(),
- builder.getIntegerAttr(builder.getI1Type(),
- getOperationIdentity(intrinsicOp, loc)));
- return builder.createConvert(loc, type, intConst);
+ builder.getIntegerAttr(type, getOperationIdentity(redId, loc)));
+ case ReductionIdentifier::ID:
+ case ReductionIdentifier::USER_DEF_OP:
+ case ReductionIdentifier::SUBTRACT:
+ TODO(loc, "Reduction of some identifier types is not supported");
}
-
- return builder.create<mlir::arith::ConstantOp>(
- loc, type,
- builder.getIntegerAttr(type, getOperationIdentity(intrinsicOp, loc)));
+ llvm_unreachable("Unhandled Reduction identifier : getReductionInitValue");
}
template <typename FloatOp, typename IntegerOp>
@@ -910,118 +950,46 @@ class ReductionProcessor {
return builder.create<FloatOp>(loc, op1, op2);
}
- /// Creates an OpenMP reduction declaration and inserts it into the provided
- /// symbol table. The declaration has a constant initializer with the neutral
- /// value `initValue`, and the reduction combiner carried over from `reduce`.
- /// TODO: Generalize this for non-integer types, add atomic region.
- static mlir::omp::ReductionDeclareOp createReductionDecl(
- fir::FirOpBuilder &builder, llvm::StringRef reductionOpName,
- const Fortran::parser::ProcedureDesignator &procDesignator,
- mlir::Type type, mlir::Location loc) {
- mlir::OpBuilder::InsertionGuard guard(builder);
- mlir::ModuleOp module = builder.getModule();
-
- auto decl =
- module.lookupSymbol<mlir::omp::ReductionDeclareOp>(reductionOpName);
- if (decl)
- return decl;
-
- mlir::OpBuilder modBuilder(module.getBodyRegion());
-
- decl = modBuilder.create<mlir::omp::ReductionDeclareOp>(
- loc, reductionOpName, type);
- builder.createBlock(&decl.getInitializerRegion(),
- decl.getInitializerRegion().end(), {type}, {loc});
- builder.setInsertionPointToEnd(&decl.getInitializerRegion().back());
- mlir::Value init =
- getIntrinsicProcInitValue(loc, type, procDesignator, builder);
- builder.create<mlir::omp::YieldOp>(loc, init);
-
- builder.createBlock(&decl.getReductionRegion(),
- decl.getReductionRegion().end(), {type, type},
- {loc, loc});
-
- builder.setInsertionPointToEnd(&decl.getReductionRegion().back());
- mlir::Value op1 = decl.getReductionRegion().front().getArgument(0);
- mlir::Value op2 = decl.getReductionRegion().front().getArgument(1);
-
+ static mlir::Value createScalarCombiner(fir::FirOpBuilder &builder,
+ mlir::Location loc,
+ ReductionIdentifier redId,
+ mlir::Type type, mlir::Value op1,
+ mlir::Value op2) {
mlir::Value reductionOp;
- switch (getReductionType(procDesignator)) {
- case IntrinsicProc::MAX:
+ switch (redId) {
+ case ReductionIdentifier::MAX:
reductionOp =
getReductionOperation<mlir::arith::MaximumFOp, mlir::arith::MaxSIOp>(
builder, type, loc, op1, op2);
break;
- case IntrinsicProc::MIN:
+ case ReductionIdentifier::MIN:
reductionOp =
getReductionOperation<mlir::arith::MinimumFOp, mlir::arith::MinSIOp>(
builder, type, loc, op1, op2);
break;
- case IntrinsicProc::IOR:
+ case ReductionIdentifier::IOR:
assert((type.isIntOrIndex()) && "only integer is expected");
reductionOp = builder.create<mlir::arith::OrIOp>(loc, op1, op2);
break;
- case IntrinsicProc::IEOR:
+ case ReductionIdentifier::IEOR:
assert((type.isIntOrIndex()) && "only integer is expected");
reductionOp = builder.create<mlir::arith::XOrIOp>(loc, op1, op2);
break;
- case IntrinsicProc::IAND:
+ case ReductionIdentifier::IAND:
assert((type.isIntOrIndex()) && "only integer is expected");
reductionOp = builder.create<mlir::arith::AndIOp>(loc, op1, op2);
break;
- }
-
- builder.create<mlir::omp::YieldOp>(loc, reductionOp);
- return decl;
- }
-
- /// Creates an OpenMP reduction declaration and inserts it into the provided
- /// symbol table. The declaration has a constant initializer with the neutral
- /// value `initValue`, and the reduction combiner carried over from `reduce`.
- /// TODO: Generalize this for non-integer types, add atomic region.
- static mlir::omp::ReductionDeclareOp createReductionDecl(
- fir::FirOpBuilder &builder, llvm::StringRef reductionOpName,
- Fortran::parser::DefinedOperator::IntrinsicOperator intrinsicOp,
- mlir::Type type, mlir::Location loc) {
- mlir::OpBuilder::InsertionGuard guard(builder);
- mlir::ModuleOp module = builder.getModule();
-
- auto decl =
- module.lookupSymbol<mlir::omp::ReductionDeclareOp>(reductionOpName);
- if (decl)
- return decl;
-
- mlir::OpBuilder modBuilder(module.getBodyRegion());
-
- decl = modBuilder.create<mlir::omp::ReductionDeclareOp>(
- loc, reductionOpName, type);
- builder.createBlock(&decl.getInitializerRegion(),
- decl.getInitializerRegion().end(), {type}, {loc});
- builder.setInsertionPointToEnd(&decl.getInitializerRegion().back());
- mlir::Value init = getIntrinsicOpInitValue(loc, type, intrinsicOp, builder);
- builder.create<mlir::omp::YieldOp>(loc, init);
-
- builder.createBlock(&decl.getReductionRegion(),
- decl.getReductionRegion().end(), {type, type},
- {loc, loc});
-
- builder.setInsertionPointToEnd(&decl.getReductionRegion().back());
- mlir::Value op1 = decl.getReductionRegion().front().getArgument(0);
- mlir::Value op2 = decl.getReductionRegion().front().getArgument(1);
-
- mlir::Value reductionOp;
- switch (intrinsicOp) {
- case Fortran::parser::DefinedOperator::IntrinsicOperator::Add:
+ case ReductionIdentifier::ADD:
reductionOp =
getReductionOperation<mlir::arith::AddFOp, mlir::arith::AddIOp>(
builder, type, loc, op1, op2);
break;
- case Fortran::parser::DefinedOperator::IntrinsicOperator::Multiply:
+ case ReductionIdentifier::MULTIPLY:
reductionOp =
getReductionOperation<mlir::arith::MulFOp, mlir::arith::MulIOp>(
builder, type, loc, op1, op2);
break;
- case Fortran::parser::DefinedOperator::IntrinsicOperator::AND: {
+ case ReductionIdentifier::AND: {
mlir::Value op1I1 = builder.createConvert(loc, builder.getI1Type(), op1);
mlir::Value op2I1 = builder.createConvert(loc, builder.getI1Type(), op2);
@@ -1031,7 +999,7 @@ class ReductionProcessor {
reductionOp = builder.createConvert(loc, type, andiOp);
break;
}
- case Fortran::parser::DefinedOperator::IntrinsicOperator::OR: {
+ case ReductionIdentifier::OR: {
mlir::Value op1I1 = builder.createConvert(loc, builder.getI1Type(), op1);
mlir::Value op2I1 = builder.createConvert(loc, builder.getI1Type(), op2);
@@ -1040,7 +1008,7 @@ class ReductionProcessor {
reductionOp = builder.createConvert(loc, type, oriOp);
break;
}
- case Fortran::parser::DefinedOperator::IntrinsicOperator::EQV: {
+ case ReductionIdentifier::EQV: {
mlir::Value op1I1 = builder.createConvert(loc, builder.getI1Type(), op1);
mlir::Value op2I1 = builder.createConvert(loc, builder.getI1Type(), op2);
@@ -1050,7 +1018,7 @@ class ReductionProcessor {
reductionOp = builder.createConvert(loc, type, cmpiOp);
break;
}
- case Fortran::parser::DefinedOperator::IntrinsicOperator::NEQV: {
+ case ReductionIdentifier::NEQV: {
mlir::Value op1I1 = builder.createConvert(loc, builder.getI1Type(), op1);
mlir::Value op2I1 = builder.createConvert(loc, builder.getI1Type(), op2);
@@ -1064,7 +1032,46 @@ class ReductionProcessor {
TODO(loc, "Reduction of some intrinsic operators is not supported");
}
+ return reductionOp;
+ }
+
+ /// Creates an OpenMP reduction declaration and inserts it into the provided
+ /// symbol table. The declaration has a constant initializer with the neutral
+ /// value `initValue`, and the reduction combiner carried over from `reduce`.
+ /// TODO: Generalize this for non-integer types, add atomic region.
+ static mlir::omp::ReductionDeclareOp createReductionDecl(
+ fir::FirOpBuilder &builder, llvm::StringRef reductionOpName,
+ const ReductionIdentifier redId, mlir::Type type, mlir::Location loc) {
+ mlir::OpBuilder::InsertionGuard guard(builder);
+ mlir::ModuleOp module = builder.getModule();
+
+ auto decl =
+ module.lookupSymbol<mlir::omp::ReductionDeclareOp>(reductionOpName);
+ if (decl)
+ return decl;
+
+ mlir::OpBuilder modBuilder(module.getBodyRegion());
+
+ decl = modBuilder.create<mlir::omp::ReductionDeclareOp>(
+ loc, reductionOpName, type);
+ builder.createBlock(&decl.getInitializerRegion(),
+ decl.getInitializerRegion().end(), {type}, {loc});
+ builder.setInsertionPointToEnd(&decl.getInitializerRegion().back());
+ mlir::Value init = getReductionInitValue(loc, type, redId, builder);
+ builder.create<mlir::omp::YieldOp>(loc, init);
+
+ builder.createBlock(&decl.getReductionRegion(),
+ decl.getReductionRegion().end(), {type, type},
+ {loc, loc});
+
+ builder.setInsertionPointToEnd(&decl.getReductionRegion().back());
+ mlir::Value op1 = decl.getReductionRegion().front().getArgument(0);
+ mlir::Value op2 = decl.getReductionRegion().front().getArgument(1);
+
+ mlir::Value reductionOp =
+ createScalarCombiner(builder, loc, redId, type, op1, op2);
builder.create<mlir::omp::YieldOp>(loc, reductionOp);
+
return decl;
}
@@ -1087,15 +1094,15 @@ class ReductionProcessor {
const auto &intrinsicOp{
std::get<Fortran::parser::DefinedOperator::IntrinsicOperator>(
redDefinedOp->u)};
- switch (intrinsicOp) {
- case Fortran::parser::DefinedOperator::IntrinsicOperator::Add:
- case Fortran::parser::DefinedOperator::IntrinsicOperator::Multiply:
- case Fortran::parser::DefinedOperator::IntrinsicOperator::AND:
- case Fortran::parser::DefinedOperator::IntrinsicOperator::EQV:
- case Fortran::parser::DefinedOperator::IntrinsicOperator::OR:
- case Fortran::parser::DefinedOperator::IntrinsicOperator::NEQV:
+ ReductionIdentifier redId = getReductionType(intrinsicOp);
+ switch (redId) {
+ case ReductionIdentifier::ADD:
+ case ReductionIdentifier::MULTIPLY:
+ case ReductionIdentifier::AND:
+ case ReductionIdentifier::EQV:
+ case ReductionIdentifier::OR:
+ case ReductionIdentifier::NEQV:
break;
-
default:
TODO(currentLocation,
"Reduction of some intrinsic operators is not supported");
@@ -1115,11 +1122,11 @@ class ReductionProcessor {
decl = createReductionDecl(
firOpBuilder,
getReductionName(intrinsicOp, firOpBuilder.getI1Type()),
- intrinsicOp, redType, currentLocation);
+ redId, redType, currentLocation);
else if (redType.isIntOrIndexOrFloat()) {
decl = createReductionDecl(firOpBuilder,
...
[truncated]
``````````
</details>
https://github.com/llvm/llvm-project/pull/79876
More information about the flang-commits mailing list