[flang-commits] [flang] [Flang][OpenMP] NFC: Minor refactoring of Reduction lowering code (PR #70790)
Kiran Chandramohan via flang-commits
flang-commits at lists.llvm.org
Wed Nov 22 06:06:05 PST 2023
https://github.com/kiranchandramohan updated https://github.com/llvm/llvm-project/pull/70790
>From d0f6d3eef144ec25278f50b0d16f03928dd03377 Mon Sep 17 00:00:00 2001
From: Kiran Chandramohan <kiran.chandramohan at arm.com>
Date: Tue, 31 Oct 2023 10:43:30 +0000
Subject: [PATCH] [Flang][OpenMP] NFC: Minor refactoring of Reduction lowering
code
Move reduction lowering code into a ReductionProcessor class.
Create an enumeration for Intrinsic Procedure reductions.
---
flang/lib/Lower/OpenMP.cpp | 793 ++++++++++++++++++++-----------------
1 file changed, 433 insertions(+), 360 deletions(-)
diff --git a/flang/lib/Lower/OpenMP.cpp b/flang/lib/Lower/OpenMP.cpp
index f6a61ba3a528e32..6267231d7fbe253 100644
--- a/flang/lib/Lower/OpenMP.cpp
+++ b/flang/lib/Lower/OpenMP.cpp
@@ -685,276 +685,441 @@ static void checkMapType(mlir::Location location, mlir::Type type) {
TODO(location, "OMPD_target_data MapOperand BoxType");
}
-static std::string getReductionName(llvm::StringRef name, mlir::Type ty) {
- return (llvm::Twine(name) +
- (ty.isIntOrIndex() ? llvm::Twine("_i_") : llvm::Twine("_f_")) +
- llvm::Twine(ty.getIntOrFloatBitWidth()))
- .str();
-}
+class ReductionProcessor {
+public:
+ enum IntrinsicProc { MAX, MIN, IAND, IOR, IEOR };
+ static IntrinsicProc
+ getReductionType(const Fortran::parser::ProcedureDesignator &pd) {
+ auto redType = llvm::StringSwitch<std::optional<IntrinsicProc>>(
+ getRealName(pd).ToString())
+ .Case("max", IntrinsicProc::MAX)
+ .Case("min", IntrinsicProc::MIN)
+ .Case("iand", IntrinsicProc::IAND)
+ .Case("ior", IntrinsicProc::IOR)
+ .Case("ieor", IntrinsicProc::IEOR)
+ .Default(std::nullopt);
+ assert(redType && "Invalid Reduction");
+ return *redType;
+ }
+
+ static bool supportedIntrinsicProcReduction(
+ const Fortran::parser::ProcedureDesignator &pd) {
+ const auto *name{Fortran::parser::Unwrap<Fortran::parser::Name>(pd)};
+ assert(name && "Invalid Reduction Intrinsic.");
+ auto redType = llvm::StringSwitch<std::optional<IntrinsicProc>>(
+ getRealName(name).ToString())
+ .Case("max", IntrinsicProc::MAX)
+ .Case("min", IntrinsicProc::MIN)
+ .Case("iand", IntrinsicProc::IAND)
+ .Case("ior", IntrinsicProc::IOR)
+ .Case("ieor", IntrinsicProc::IEOR)
+ .Default(std::nullopt);
+ if (redType)
+ return true;
+ return false;
+ }
-static std::string getReductionName(
- Fortran::parser::DefinedOperator::IntrinsicOperator intrinsicOp,
- mlir::Type ty) {
- std::string reductionName;
+ static const Fortran::semantics::SourceName
+ getRealName(const Fortran::parser::Name *name) {
+ return name->symbol->GetUltimate().name();
+ }
- switch (intrinsicOp) {
- case Fortran::parser::DefinedOperator::IntrinsicOperator::Add:
- reductionName = "add_reduction";
- break;
- case Fortran::parser::DefinedOperator::IntrinsicOperator::Multiply:
- reductionName = "multiply_reduction";
- break;
- case Fortran::parser::DefinedOperator::IntrinsicOperator::AND:
- return "and_reduction";
- case Fortran::parser::DefinedOperator::IntrinsicOperator::EQV:
- return "eqv_reduction";
- case Fortran::parser::DefinedOperator::IntrinsicOperator::OR:
- return "or_reduction";
- case Fortran::parser::DefinedOperator::IntrinsicOperator::NEQV:
- return "neqv_reduction";
- default:
- reductionName = "other_reduction";
- break;
+ static const Fortran::semantics::SourceName
+ getRealName(const Fortran::parser::ProcedureDesignator &pd) {
+ const auto *name{Fortran::parser::Unwrap<Fortran::parser::Name>(pd)};
+ assert(name && "Invalid Reduction Intrinsic.");
+ return getRealName(name);
}
- return getReductionName(reductionName, ty);
-}
-
-/// This function returns the identity value of the operator \p reductionOpName.
-/// For example:
-/// 0 + x = x,
-/// 1 * x = x
-static int getOperationIdentity(llvm::StringRef reductionOpName,
- mlir::Location loc) {
- if (reductionOpName.contains("add") || reductionOpName.contains("or") ||
- reductionOpName.contains("neqv"))
- return 0;
- if (reductionOpName.contains("multiply") || reductionOpName.contains("and") ||
- reductionOpName.contains("eqv"))
- return 1;
- TODO(loc, "Reduction of some intrinsic operators is not supported");
-}
-
-static mlir::Value getReductionInitValue(mlir::Location loc, mlir::Type type,
- llvm::StringRef reductionOpName,
- fir::FirOpBuilder &builder) {
- assert((fir::isa_integer(type) || fir::isa_real(type) ||
- type.isa<fir::LogicalType>()) &&
- "only integer, logical and real types are currently supported");
- if (reductionOpName.contains("max")) {
- if (auto ty = type.dyn_cast<mlir::FloatType>()) {
- const llvm::fltSemantics &sem = ty.getFloatSemantics();
- return builder.createRealConstant(
- loc, type, llvm::APFloat::getLargest(sem, /*Negative=*/true));
- }
- unsigned bits = type.getIntOrFloatBitWidth();
- int64_t minInt = llvm::APInt::getSignedMinValue(bits).getSExtValue();
- return builder.createIntegerConstant(loc, type, minInt);
- } else if (reductionOpName.contains("min")) {
- if (auto ty = type.dyn_cast<mlir::FloatType>()) {
- const llvm::fltSemantics &sem = ty.getFloatSemantics();
- return builder.createRealConstant(
- loc, type, llvm::APFloat::getLargest(sem, /*Negative=*/false));
- }
- unsigned bits = type.getIntOrFloatBitWidth();
- int64_t maxInt = llvm::APInt::getSignedMaxValue(bits).getSExtValue();
- return builder.createIntegerConstant(loc, type, maxInt);
- } else if (reductionOpName.contains("ior")) {
- unsigned bits = type.getIntOrFloatBitWidth();
- int64_t zeroInt = llvm::APInt::getZero(bits).getSExtValue();
- return builder.createIntegerConstant(loc, type, zeroInt);
- } else if (reductionOpName.contains("ieor")) {
- unsigned bits = type.getIntOrFloatBitWidth();
- int64_t zeroInt = llvm::APInt::getZero(bits).getSExtValue();
- return builder.createIntegerConstant(loc, type, zeroInt);
- } else if (reductionOpName.contains("iand")) {
- unsigned bits = type.getIntOrFloatBitWidth();
- int64_t allOnInt = llvm::APInt::getAllOnes(bits).getSExtValue();
- return builder.createIntegerConstant(loc, type, allOnInt);
- } else {
+ static std::string getReductionName(llvm::StringRef name, mlir::Type ty) {
+ return (llvm::Twine(name) +
+ (ty.isIntOrIndex() ? llvm::Twine("_i_") : llvm::Twine("_f_")) +
+ llvm::Twine(ty.getIntOrFloatBitWidth()))
+ .str();
+ }
+
+ static std::string getReductionName(
+ Fortran::parser::DefinedOperator::IntrinsicOperator intrinsicOp,
+ mlir::Type ty) {
+ std::string reductionName;
+
+ switch (intrinsicOp) {
+ case Fortran::parser::DefinedOperator::IntrinsicOperator::Add:
+ reductionName = "add_reduction";
+ break;
+ case Fortran::parser::DefinedOperator::IntrinsicOperator::Multiply:
+ reductionName = "multiply_reduction";
+ break;
+ case Fortran::parser::DefinedOperator::IntrinsicOperator::AND:
+ return "and_reduction";
+ case Fortran::parser::DefinedOperator::IntrinsicOperator::EQV:
+ return "eqv_reduction";
+ case Fortran::parser::DefinedOperator::IntrinsicOperator::OR:
+ return "or_reduction";
+ case Fortran::parser::DefinedOperator::IntrinsicOperator::NEQV:
+ return "neqv_reduction";
+ default:
+ reductionName = "other_reduction";
+ break;
+ }
+
+ return getReductionName(reductionName, ty);
+ }
+
+ /// This function returns the identity value of the operator \p
+ /// intrinsicOp. For example:
+ /// 0 + x = x,
+ /// 1 * x = x
+ static int getOperationIdentity(
+ Fortran::parser::DefinedOperator::IntrinsicOperator intrinsicOp,
+ mlir::Location loc) {
+ switch (intrinsicOp) {
+ case Fortran::parser::DefinedOperator::IntrinsicOperator::Add:
+ case Fortran::parser::DefinedOperator::IntrinsicOperator::OR:
+ case Fortran::parser::DefinedOperator::IntrinsicOperator::NEQV:
+ return 0;
+ case Fortran::parser::DefinedOperator::IntrinsicOperator::Multiply:
+ case Fortran::parser::DefinedOperator::IntrinsicOperator::AND:
+ case Fortran::parser::DefinedOperator::IntrinsicOperator::EQV:
+ return 1;
+ default:
+ TODO(loc, "Reduction of some intrinsic operators is not supported");
+ }
+ }
+
+ static mlir::Value getIntrinsicProcInitValue(
+ mlir::Location loc, mlir::Type type,
+ const Fortran::parser::ProcedureDesignator &procDesignator,
+ fir::FirOpBuilder &builder) {
+ assert((fir::isa_integer(type) || fir::isa_real(type) ||
+ type.isa<fir::LogicalType>()) &&
+ "only integer, logical and real types are currently supported");
+ switch (getReductionType(procDesignator)) {
+ case IntrinsicProc::MAX: {
+ if (auto ty = type.dyn_cast<mlir::FloatType>()) {
+ const llvm::fltSemantics &sem = ty.getFloatSemantics();
+ return builder.createRealConstant(
+ loc, type, llvm::APFloat::getLargest(sem, /*Negative=*/true));
+ }
+ unsigned bits = type.getIntOrFloatBitWidth();
+ int64_t minInt = llvm::APInt::getSignedMinValue(bits).getSExtValue();
+ return builder.createIntegerConstant(loc, type, minInt);
+ }
+ case IntrinsicProc::MIN: {
+ if (auto ty = type.dyn_cast<mlir::FloatType>()) {
+ const llvm::fltSemantics &sem = ty.getFloatSemantics();
+ return builder.createRealConstant(
+ loc, type, llvm::APFloat::getLargest(sem, /*Negative=*/false));
+ }
+ unsigned bits = type.getIntOrFloatBitWidth();
+ int64_t maxInt = llvm::APInt::getSignedMaxValue(bits).getSExtValue();
+ return builder.createIntegerConstant(loc, type, maxInt);
+ }
+ case IntrinsicProc::IOR: {
+ unsigned bits = type.getIntOrFloatBitWidth();
+ int64_t zeroInt = llvm::APInt::getZero(bits).getSExtValue();
+ return builder.createIntegerConstant(loc, type, zeroInt);
+ }
+ case IntrinsicProc::IEOR: {
+ unsigned bits = type.getIntOrFloatBitWidth();
+ int64_t zeroInt = llvm::APInt::getZero(bits).getSExtValue();
+ return builder.createIntegerConstant(loc, type, zeroInt);
+ }
+ case IntrinsicProc::IAND: {
+ unsigned bits = type.getIntOrFloatBitWidth();
+ int64_t allOnInt = llvm::APInt::getAllOnes(bits).getSExtValue();
+ return builder.createIntegerConstant(loc, type, allOnInt);
+ }
+ }
+ llvm_unreachable("Unknown Reduction Intrinsic");
+ }
+
+ static mlir::Value getIntrinsicOpInitValue(
+ mlir::Location loc, mlir::Type type,
+ Fortran::parser::DefinedOperator::IntrinsicOperator intrinsicOp,
+ fir::FirOpBuilder &builder) {
if (type.isa<mlir::FloatType>())
return builder.create<mlir::arith::ConstantOp>(
loc, type,
- builder.getFloatAttr(
- type, (double)getOperationIdentity(reductionOpName, loc)));
+ builder.getFloatAttr(type,
+ (double)getOperationIdentity(intrinsicOp, loc)));
if (type.isa<fir::LogicalType>()) {
mlir::Value intConst = builder.create<mlir::arith::ConstantOp>(
loc, builder.getI1Type(),
builder.getIntegerAttr(builder.getI1Type(),
- getOperationIdentity(reductionOpName, loc)));
+ getOperationIdentity(intrinsicOp, loc)));
return builder.createConvert(loc, type, intConst);
}
return builder.create<mlir::arith::ConstantOp>(
loc, type,
- builder.getIntegerAttr(type,
- getOperationIdentity(reductionOpName, loc)));
- }
-}
-
-template <typename FloatOp, typename IntegerOp>
-static mlir::Value getReductionOperation(fir::FirOpBuilder &builder,
- mlir::Type type, mlir::Location loc,
- mlir::Value op1, mlir::Value op2) {
- assert(type.isIntOrIndexOrFloat() &&
- "only integer and float types are currently supported");
- if (type.isIntOrIndex())
- return builder.create<IntegerOp>(loc, op1, op2);
- return builder.create<FloatOp>(loc, op1, op2);
-}
-
-static mlir::omp::ReductionDeclareOp
-createMinimalReductionDecl(fir::FirOpBuilder &builder,
- llvm::StringRef reductionOpName, mlir::Type type,
- mlir::Location loc) {
- mlir::ModuleOp module = builder.getModule();
- mlir::OpBuilder modBuilder(module.getBodyRegion());
-
- mlir::omp::ReductionDeclareOp decl =
- modBuilder.create<mlir::omp::ReductionDeclareOp>(loc, reductionOpName,
- type);
- builder.createBlock(&decl.getInitializerRegion(),
- decl.getInitializerRegion().end(), {type}, {loc});
- builder.setInsertionPointToEnd(&decl.getInitializerRegion().back());
- mlir::Value init = getReductionInitValue(loc, type, reductionOpName, builder);
- builder.create<mlir::omp::YieldOp>(loc, init);
-
- builder.createBlock(&decl.getReductionRegion(),
- decl.getReductionRegion().end(), {type, type},
- {loc, loc});
-
- return decl;
-}
-
-/// Creates an OpenMP reduction declaration and inserts it into the provided
-/// symbol table. The declaration has a constant initializer with the neutral
-/// value `initValue`, and the reduction combiner carried over from `reduce`.
-/// TODO: Generalize this for non-integer types, add atomic region.
-static mlir::omp::ReductionDeclareOp
-createReductionDecl(fir::FirOpBuilder &builder, llvm::StringRef reductionOpName,
- const Fortran::parser::ProcedureDesignator &procDesignator,
- mlir::Type type, mlir::Location loc) {
- mlir::OpBuilder::InsertionGuard guard(builder);
- mlir::ModuleOp module = builder.getModule();
-
- auto decl =
- module.lookupSymbol<mlir::omp::ReductionDeclareOp>(reductionOpName);
- if (decl)
- return decl;
+ builder.getIntegerAttr(type, getOperationIdentity(intrinsicOp, loc)));
+ }
+
+ template <typename FloatOp, typename IntegerOp>
+ static mlir::Value getReductionOperation(fir::FirOpBuilder &builder,
+ mlir::Type type, mlir::Location loc,
+ mlir::Value op1, mlir::Value op2) {
+ assert(type.isIntOrIndexOrFloat() &&
+ "only integer and float types are currently supported");
+ if (type.isIntOrIndex())
+ return builder.create<IntegerOp>(loc, op1, op2);
+ return builder.create<FloatOp>(loc, op1, op2);
+ }
+
+ /// Creates an OpenMP reduction declaration named \p reductionOpName in the
+ /// module, or returns the existing one. Its initializer region yields the
+ /// neutral value and its reduction region yields the combined operands.
+ /// TODO: Generalize this for non-integer types, add atomic region.
+ static mlir::omp::ReductionDeclareOp createReductionDecl(
+ fir::FirOpBuilder &builder, llvm::StringRef reductionOpName,
+ const Fortran::parser::ProcedureDesignator &procDesignator,
+ mlir::Type type, mlir::Location loc) {
+ mlir::OpBuilder::InsertionGuard guard(builder);
+ mlir::ModuleOp module = builder.getModule();
+
+ auto decl =
+ module.lookupSymbol<mlir::omp::ReductionDeclareOp>(reductionOpName);
+ if (decl)
+ return decl;
- decl = createMinimalReductionDecl(builder, reductionOpName, type, loc);
- builder.setInsertionPointToEnd(&decl.getReductionRegion().back());
- mlir::Value op1 = decl.getReductionRegion().front().getArgument(0);
- mlir::Value op2 = decl.getReductionRegion().front().getArgument(1);
+ mlir::OpBuilder modBuilder(module.getBodyRegion());
- mlir::Value reductionOp;
- if (const auto *name{
- Fortran::parser::Unwrap<Fortran::parser::Name>(procDesignator)}) {
- if (name->source == "max") {
+ decl = modBuilder.create<mlir::omp::ReductionDeclareOp>(
+ loc, reductionOpName, type);
+ builder.createBlock(&decl.getInitializerRegion(),
+ decl.getInitializerRegion().end(), {type}, {loc});
+ builder.setInsertionPointToEnd(&decl.getInitializerRegion().back());
+ mlir::Value init =
+ getIntrinsicProcInitValue(loc, type, procDesignator, builder);
+ builder.create<mlir::omp::YieldOp>(loc, init);
+
+ builder.createBlock(&decl.getReductionRegion(),
+ decl.getReductionRegion().end(), {type, type},
+ {loc, loc});
+
+ builder.setInsertionPointToEnd(&decl.getReductionRegion().back());
+ mlir::Value op1 = decl.getReductionRegion().front().getArgument(0);
+ mlir::Value op2 = decl.getReductionRegion().front().getArgument(1);
+
+ mlir::Value reductionOp;
+ switch (getReductionType(procDesignator)) {
+ case IntrinsicProc::MAX:
reductionOp =
getReductionOperation<mlir::arith::MaximumFOp, mlir::arith::MaxSIOp>(
builder, type, loc, op1, op2);
- } else if (name->source == "min") {
+ break;
+ case IntrinsicProc::MIN:
reductionOp =
getReductionOperation<mlir::arith::MinimumFOp, mlir::arith::MinSIOp>(
builder, type, loc, op1, op2);
- } else if (name->source == "ior") {
+ break;
+ case IntrinsicProc::IOR:
assert((type.isIntOrIndex()) && "only integer is expected");
reductionOp = builder.create<mlir::arith::OrIOp>(loc, op1, op2);
- } else if (name->source == "ieor") {
+ break;
+ case IntrinsicProc::IEOR:
assert((type.isIntOrIndex()) && "only integer is expected");
reductionOp = builder.create<mlir::arith::XOrIOp>(loc, op1, op2);
- } else if (name->source == "iand") {
+ break;
+ case IntrinsicProc::IAND:
assert((type.isIntOrIndex()) && "only integer is expected");
reductionOp = builder.create<mlir::arith::AndIOp>(loc, op1, op2);
- } else {
- TODO(loc, "Reduction of some intrinsic operators is not supported");
+ break;
+ default:
+ llvm_unreachable(
+ "Reduction of some intrinsic operators is not supported");
}
+
+ builder.create<mlir::omp::YieldOp>(loc, reductionOp);
+ return decl;
}
- builder.create<mlir::omp::YieldOp>(loc, reductionOp);
- return decl;
-}
+ /// Creates an OpenMP reduction declaration named \p reductionOpName in the
+ /// module, or returns the existing one. Its initializer region yields the
+ /// identity of \p intrinsicOp and its reduction region yields the result.
+ /// TODO: Generalize this for non-integer types, add atomic region.
+ static mlir::omp::ReductionDeclareOp createReductionDecl(
+ fir::FirOpBuilder &builder, llvm::StringRef reductionOpName,
+ Fortran::parser::DefinedOperator::IntrinsicOperator intrinsicOp,
+ mlir::Type type, mlir::Location loc) {
+ mlir::OpBuilder::InsertionGuard guard(builder);
+ mlir::ModuleOp module = builder.getModule();
-/// Creates an OpenMP reduction declaration and inserts it into the provided
-/// symbol table. The declaration has a constant initializer with the neutral
-/// value `initValue`, and the reduction combiner carried over from `reduce`.
-/// TODO: Generalize this for non-integer types, add atomic region.
-static mlir::omp::ReductionDeclareOp createReductionDecl(
- fir::FirOpBuilder &builder, llvm::StringRef reductionOpName,
- Fortran::parser::DefinedOperator::IntrinsicOperator intrinsicOp,
- mlir::Type type, mlir::Location loc) {
- mlir::OpBuilder::InsertionGuard guard(builder);
- mlir::ModuleOp module = builder.getModule();
+ auto decl =
+ module.lookupSymbol<mlir::omp::ReductionDeclareOp>(reductionOpName);
+ if (decl)
+ return decl;
- auto decl =
- module.lookupSymbol<mlir::omp::ReductionDeclareOp>(reductionOpName);
- if (decl)
- return decl;
+ mlir::OpBuilder modBuilder(module.getBodyRegion());
- decl = createMinimalReductionDecl(builder, reductionOpName, type, loc);
- builder.setInsertionPointToEnd(&decl.getReductionRegion().back());
- mlir::Value op1 = decl.getReductionRegion().front().getArgument(0);
- mlir::Value op2 = decl.getReductionRegion().front().getArgument(1);
+ decl = modBuilder.create<mlir::omp::ReductionDeclareOp>(
+ loc, reductionOpName, type);
+ builder.createBlock(&decl.getInitializerRegion(),
+ decl.getInitializerRegion().end(), {type}, {loc});
+ builder.setInsertionPointToEnd(&decl.getInitializerRegion().back());
+ mlir::Value init = getIntrinsicOpInitValue(loc, type, intrinsicOp, builder);
+ builder.create<mlir::omp::YieldOp>(loc, init);
- mlir::Value reductionOp;
- switch (intrinsicOp) {
- case Fortran::parser::DefinedOperator::IntrinsicOperator::Add:
- reductionOp =
- getReductionOperation<mlir::arith::AddFOp, mlir::arith::AddIOp>(
- builder, type, loc, op1, op2);
- break;
- case Fortran::parser::DefinedOperator::IntrinsicOperator::Multiply:
- reductionOp =
- getReductionOperation<mlir::arith::MulFOp, mlir::arith::MulIOp>(
- builder, type, loc, op1, op2);
- break;
- case Fortran::parser::DefinedOperator::IntrinsicOperator::AND: {
- mlir::Value op1I1 = builder.createConvert(loc, builder.getI1Type(), op1);
- mlir::Value op2I1 = builder.createConvert(loc, builder.getI1Type(), op2);
+ builder.createBlock(&decl.getReductionRegion(),
+ decl.getReductionRegion().end(), {type, type},
+ {loc, loc});
- mlir::Value andiOp = builder.create<mlir::arith::AndIOp>(loc, op1I1, op2I1);
+ builder.setInsertionPointToEnd(&decl.getReductionRegion().back());
+ mlir::Value op1 = decl.getReductionRegion().front().getArgument(0);
+ mlir::Value op2 = decl.getReductionRegion().front().getArgument(1);
- reductionOp = builder.createConvert(loc, type, andiOp);
- break;
- }
- case Fortran::parser::DefinedOperator::IntrinsicOperator::OR: {
- mlir::Value op1I1 = builder.createConvert(loc, builder.getI1Type(), op1);
- mlir::Value op2I1 = builder.createConvert(loc, builder.getI1Type(), op2);
+ mlir::Value reductionOp;
+ switch (intrinsicOp) {
+ case Fortran::parser::DefinedOperator::IntrinsicOperator::Add:
+ reductionOp =
+ getReductionOperation<mlir::arith::AddFOp, mlir::arith::AddIOp>(
+ builder, type, loc, op1, op2);
+ break;
+ case Fortran::parser::DefinedOperator::IntrinsicOperator::Multiply:
+ reductionOp =
+ getReductionOperation<mlir::arith::MulFOp, mlir::arith::MulIOp>(
+ builder, type, loc, op1, op2);
+ break;
+ case Fortran::parser::DefinedOperator::IntrinsicOperator::AND: {
+ mlir::Value op1I1 = builder.createConvert(loc, builder.getI1Type(), op1);
+ mlir::Value op2I1 = builder.createConvert(loc, builder.getI1Type(), op2);
- mlir::Value oriOp = builder.create<mlir::arith::OrIOp>(loc, op1I1, op2I1);
+ mlir::Value andiOp =
+ builder.create<mlir::arith::AndIOp>(loc, op1I1, op2I1);
- reductionOp = builder.createConvert(loc, type, oriOp);
- break;
- }
- case Fortran::parser::DefinedOperator::IntrinsicOperator::EQV: {
- mlir::Value op1I1 = builder.createConvert(loc, builder.getI1Type(), op1);
- mlir::Value op2I1 = builder.createConvert(loc, builder.getI1Type(), op2);
+ reductionOp = builder.createConvert(loc, type, andiOp);
+ break;
+ }
+ case Fortran::parser::DefinedOperator::IntrinsicOperator::OR: {
+ mlir::Value op1I1 = builder.createConvert(loc, builder.getI1Type(), op1);
+ mlir::Value op2I1 = builder.createConvert(loc, builder.getI1Type(), op2);
- mlir::Value cmpiOp = builder.create<mlir::arith::CmpIOp>(
- loc, mlir::arith::CmpIPredicate::eq, op1I1, op2I1);
+ mlir::Value oriOp = builder.create<mlir::arith::OrIOp>(loc, op1I1, op2I1);
- reductionOp = builder.createConvert(loc, type, cmpiOp);
- break;
- }
- case Fortran::parser::DefinedOperator::IntrinsicOperator::NEQV: {
- mlir::Value op1I1 = builder.createConvert(loc, builder.getI1Type(), op1);
- mlir::Value op2I1 = builder.createConvert(loc, builder.getI1Type(), op2);
+ reductionOp = builder.createConvert(loc, type, oriOp);
+ break;
+ }
+ case Fortran::parser::DefinedOperator::IntrinsicOperator::EQV: {
+ mlir::Value op1I1 = builder.createConvert(loc, builder.getI1Type(), op1);
+ mlir::Value op2I1 = builder.createConvert(loc, builder.getI1Type(), op2);
- mlir::Value cmpiOp = builder.create<mlir::arith::CmpIOp>(
- loc, mlir::arith::CmpIPredicate::ne, op1I1, op2I1);
+ mlir::Value cmpiOp = builder.create<mlir::arith::CmpIOp>(
+ loc, mlir::arith::CmpIPredicate::eq, op1I1, op2I1);
- reductionOp = builder.createConvert(loc, type, cmpiOp);
- break;
- }
- default:
- TODO(loc, "Reduction of some intrinsic operators is not supported");
+ reductionOp = builder.createConvert(loc, type, cmpiOp);
+ break;
+ }
+ case Fortran::parser::DefinedOperator::IntrinsicOperator::NEQV: {
+ mlir::Value op1I1 = builder.createConvert(loc, builder.getI1Type(), op1);
+ mlir::Value op2I1 = builder.createConvert(loc, builder.getI1Type(), op2);
+
+ mlir::Value cmpiOp = builder.create<mlir::arith::CmpIOp>(
+ loc, mlir::arith::CmpIPredicate::ne, op1I1, op2I1);
+
+ reductionOp = builder.createConvert(loc, type, cmpiOp);
+ break;
+ }
+ default:
+ TODO(loc, "Reduction of some intrinsic operators is not supported");
+ }
+
+ builder.create<mlir::omp::YieldOp>(loc, reductionOp);
+ return decl;
}
- builder.create<mlir::omp::YieldOp>(loc, reductionOp);
- return decl;
-}
+ /// Creates a reduction declaration and associates it with an OpenMP block
+ /// directive.
+ static void addReductionDecl(
+ mlir::Location currentLocation,
+ Fortran::lower::AbstractConverter &converter,
+ const Fortran::parser::OmpReductionClause &reduction,
+ llvm::SmallVectorImpl<mlir::Value> &reductionVars,
+ llvm::SmallVectorImpl<mlir::Attribute> &reductionDeclSymbols) {
+ fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
+ mlir::omp::ReductionDeclareOp decl;
+ const auto &redOperator{
+ std::get<Fortran::parser::OmpReductionOperator>(reduction.t)};
+ const auto &objectList{
+ std::get<Fortran::parser::OmpObjectList>(reduction.t)};
+ if (const auto &redDefinedOp =
+ std::get_if<Fortran::parser::DefinedOperator>(&redOperator.u)) {
+ const auto &intrinsicOp{
+ std::get<Fortran::parser::DefinedOperator::IntrinsicOperator>(
+ redDefinedOp->u)};
+ switch (intrinsicOp) {
+ case Fortran::parser::DefinedOperator::IntrinsicOperator::Add:
+ case Fortran::parser::DefinedOperator::IntrinsicOperator::Multiply:
+ case Fortran::parser::DefinedOperator::IntrinsicOperator::AND:
+ case Fortran::parser::DefinedOperator::IntrinsicOperator::EQV:
+ case Fortran::parser::DefinedOperator::IntrinsicOperator::OR:
+ case Fortran::parser::DefinedOperator::IntrinsicOperator::NEQV:
+ break;
+
+ default:
+ TODO(currentLocation,
+ "Reduction of some intrinsic operators is not supported");
+ break;
+ }
+ for (const Fortran::parser::OmpObject &ompObject : objectList.v) {
+ if (const auto *name{
+ Fortran::parser::Unwrap<Fortran::parser::Name>(ompObject)}) {
+ if (const Fortran::semantics::Symbol * symbol{name->symbol}) {
+ mlir::Value symVal = converter.getSymbolAddress(*symbol);
+ if (auto declOp = symVal.getDefiningOp<hlfir::DeclareOp>())
+ symVal = declOp.getBase();
+ mlir::Type redType =
+ symVal.getType().cast<fir::ReferenceType>().getEleTy();
+ reductionVars.push_back(symVal);
+ if (redType.isa<fir::LogicalType>())
+ decl = createReductionDecl(
+ firOpBuilder,
+ getReductionName(intrinsicOp, firOpBuilder.getI1Type()),
+ intrinsicOp, redType, currentLocation);
+ else if (redType.isIntOrIndexOrFloat()) {
+ decl = createReductionDecl(firOpBuilder,
+ getReductionName(intrinsicOp, redType),
+ intrinsicOp, redType, currentLocation);
+ } else {
+ TODO(currentLocation, "Reduction of some types is not supported");
+ }
+ reductionDeclSymbols.push_back(mlir::SymbolRefAttr::get(
+ firOpBuilder.getContext(), decl.getSymName()));
+ }
+ }
+ }
+ } else if (const auto *reductionIntrinsic =
+ std::get_if<Fortran::parser::ProcedureDesignator>(
+ &redOperator.u)) {
+ if (ReductionProcessor::supportedIntrinsicProcReduction(
+ *reductionIntrinsic)) {
+ for (const Fortran::parser::OmpObject &ompObject : objectList.v) {
+ if (const auto *name{
+ Fortran::parser::Unwrap<Fortran::parser::Name>(ompObject)}) {
+ if (const Fortran::semantics::Symbol * symbol{name->symbol}) {
+ mlir::Value symVal = converter.getSymbolAddress(*symbol);
+ if (auto declOp = symVal.getDefiningOp<hlfir::DeclareOp>())
+ symVal = declOp.getBase();
+ mlir::Type redType =
+ symVal.getType().cast<fir::ReferenceType>().getEleTy();
+ reductionVars.push_back(symVal);
+ assert(redType.isIntOrIndexOrFloat() &&
+ "Unsupported reduction type");
+ decl = createReductionDecl(
+ firOpBuilder,
+ getReductionName(getRealName(*reductionIntrinsic).ToString(),
+ redType),
+ *reductionIntrinsic, redType, currentLocation);
+ reductionDeclSymbols.push_back(mlir::SymbolRefAttr::get(
+ firOpBuilder.getContext(), decl.getSymName()));
+ }
+ }
+ }
+ }
+ }
+ }
+};
static mlir::omp::ScheduleModifier
translateScheduleModifier(const Fortran::parser::OmpScheduleModifierType &m) {
@@ -1137,101 +1302,6 @@ static mlir::Value getIfClauseOperand(
ifVal);
}
-/// Creates a reduction declaration and associates it with an OpenMP block
-/// directive.
-static void
-addReductionDecl(mlir::Location currentLocation,
- Fortran::lower::AbstractConverter &converter,
- const Fortran::parser::OmpReductionClause &reduction,
- llvm::SmallVectorImpl<mlir::Value> &reductionVars,
- llvm::SmallVectorImpl<mlir::Attribute> &reductionDeclSymbols) {
- fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
- mlir::omp::ReductionDeclareOp decl;
- const auto &redOperator{
- std::get<Fortran::parser::OmpReductionOperator>(reduction.t)};
- const auto &objectList{std::get<Fortran::parser::OmpObjectList>(reduction.t)};
- if (const auto &redDefinedOp =
- std::get_if<Fortran::parser::DefinedOperator>(&redOperator.u)) {
- const auto &intrinsicOp{
- std::get<Fortran::parser::DefinedOperator::IntrinsicOperator>(
- redDefinedOp->u)};
- switch (intrinsicOp) {
- case Fortran::parser::DefinedOperator::IntrinsicOperator::Add:
- case Fortran::parser::DefinedOperator::IntrinsicOperator::Multiply:
- case Fortran::parser::DefinedOperator::IntrinsicOperator::AND:
- case Fortran::parser::DefinedOperator::IntrinsicOperator::EQV:
- case Fortran::parser::DefinedOperator::IntrinsicOperator::OR:
- case Fortran::parser::DefinedOperator::IntrinsicOperator::NEQV:
- break;
-
- default:
- TODO(currentLocation,
- "Reduction of some intrinsic operators is not supported");
- break;
- }
- for (const Fortran::parser::OmpObject &ompObject : objectList.v) {
- if (const auto *name{
- Fortran::parser::Unwrap<Fortran::parser::Name>(ompObject)}) {
- if (const Fortran::semantics::Symbol * symbol{name->symbol}) {
- mlir::Value symVal = converter.getSymbolAddress(*symbol);
- if (auto declOp = symVal.getDefiningOp<hlfir::DeclareOp>())
- symVal = declOp.getBase();
- mlir::Type redType =
- symVal.getType().cast<fir::ReferenceType>().getEleTy();
- reductionVars.push_back(symVal);
- if (redType.isa<fir::LogicalType>())
- decl = createReductionDecl(
- firOpBuilder,
- getReductionName(intrinsicOp, firOpBuilder.getI1Type()),
- intrinsicOp, redType, currentLocation);
- else if (redType.isIntOrIndexOrFloat()) {
- decl = createReductionDecl(firOpBuilder,
- getReductionName(intrinsicOp, redType),
- intrinsicOp, redType, currentLocation);
- } else {
- TODO(currentLocation, "Reduction of some types is not supported");
- }
- reductionDeclSymbols.push_back(mlir::SymbolRefAttr::get(
- firOpBuilder.getContext(), decl.getSymName()));
- }
- }
- }
- } else if (const auto *reductionIntrinsic =
- std::get_if<Fortran::parser::ProcedureDesignator>(
- &redOperator.u)) {
- if (const auto *name{Fortran::parser::Unwrap<Fortran::parser::Name>(
- reductionIntrinsic)}) {
- if ((name->source != "max") && (name->source != "min") &&
- (name->source != "ior") && (name->source != "ieor") &&
- (name->source != "iand")) {
- TODO(currentLocation,
- "Reduction of intrinsic procedures is not supported");
- }
- std::string intrinsicOp = name->ToString();
- for (const Fortran::parser::OmpObject &ompObject : objectList.v) {
- if (const auto *name{
- Fortran::parser::Unwrap<Fortran::parser::Name>(ompObject)}) {
- if (const Fortran::semantics::Symbol * symbol{name->symbol}) {
- mlir::Value symVal = converter.getSymbolAddress(*symbol);
- if (auto declOp = symVal.getDefiningOp<hlfir::DeclareOp>())
- symVal = declOp.getBase();
- mlir::Type redType =
- symVal.getType().cast<fir::ReferenceType>().getEleTy();
- reductionVars.push_back(symVal);
- assert(redType.isIntOrIndexOrFloat() &&
- "Unsupported reduction type");
- decl = createReductionDecl(
- firOpBuilder, getReductionName(intrinsicOp, redType),
- *reductionIntrinsic, redType, currentLocation);
- reductionDeclSymbols.push_back(mlir::SymbolRefAttr::get(
- firOpBuilder.getContext(), decl.getSymName()));
- }
- }
- }
- }
- }
-}
-
static void
addUseDeviceClause(Fortran::lower::AbstractConverter &converter,
const Fortran::parser::OmpObjectList &useDeviceClause,
@@ -1828,8 +1898,9 @@ bool ClauseProcessor::processReduction(
return findRepeatableClause<ClauseTy::Reduction>(
[&](const ClauseTy::Reduction *reductionClause,
const Fortran::parser::CharBlock &) {
- addReductionDecl(currentLocation, converter, reductionClause->v,
- reductionVars, reductionDeclSymbols);
+ ReductionProcessor rp;
+ rp.addReductionDecl(currentLocation, converter, reductionClause->v,
+ reductionVars, reductionDeclSymbols);
});
}
@@ -3665,48 +3736,50 @@ void Fortran::lower::genOpenMPReduction(
} else if (const auto *reductionIntrinsic =
std::get_if<Fortran::parser::ProcedureDesignator>(
&redOperator.u)) {
- if (const auto *name{Fortran::parser::Unwrap<Fortran::parser::Name>(
- reductionIntrinsic)}) {
- std::string redName = name->ToString();
- if ((name->source != "max") && (name->source != "min") &&
- (name->source != "ior") && (name->source != "ieor") &&
- (name->source != "iand")) {
- continue;
- }
- for (const Fortran::parser::OmpObject &ompObject : objectList.v) {
- if (const auto *name{Fortran::parser::Unwrap<Fortran::parser::Name>(
- ompObject)}) {
- if (const Fortran::semantics::Symbol * symbol{name->symbol}) {
- mlir::Value reductionVal = converter.getSymbolAddress(*symbol);
- if (auto declOp =
- reductionVal.getDefiningOp<hlfir::DeclareOp>())
- reductionVal = declOp.getBase();
- for (const mlir::OpOperand &reductionValUse :
- reductionVal.getUses()) {
- if (auto loadOp = mlir::dyn_cast<fir::LoadOp>(
- reductionValUse.getOwner())) {
- mlir::Value loadVal = loadOp.getRes();
- // Max is lowered as a compare -> select.
- // Match the pattern here.
- mlir::Operation *reductionOp =
- findReductionChain(loadVal, &reductionVal);
- if (reductionOp == nullptr)
- continue;
-
- if (redName == "max" || redName == "min") {
- assert(mlir::isa<mlir::arith::SelectOp>(reductionOp) &&
- "Selection Op not found in reduction intrinsic");
- mlir::Operation *compareOp =
- getCompareFromReductionOp(reductionOp, loadVal);
- updateReduction(compareOp, firOpBuilder, loadVal,
- reductionVal);
- }
- if (redName == "ior" || redName == "ieor" ||
- redName == "iand") {
+ if (!ReductionProcessor::supportedIntrinsicProcReduction(
+ *reductionIntrinsic))
+ continue;
+ ReductionProcessor::IntrinsicProc redIntrinsicProc =
+ ReductionProcessor::getReductionType(*reductionIntrinsic);
+ for (const Fortran::parser::OmpObject &ompObject : objectList.v) {
+ if (const auto *name{
+ Fortran::parser::Unwrap<Fortran::parser::Name>(ompObject)}) {
+ if (const Fortran::semantics::Symbol * symbol{name->symbol}) {
+ mlir::Value reductionVal = converter.getSymbolAddress(*symbol);
+ if (auto declOp = reductionVal.getDefiningOp<hlfir::DeclareOp>())
+ reductionVal = declOp.getBase();
+ for (const mlir::OpOperand &reductionValUse :
+ reductionVal.getUses()) {
+ if (auto loadOp = mlir::dyn_cast<fir::LoadOp>(
+ reductionValUse.getOwner())) {
+ mlir::Value loadVal = loadOp.getRes();
+ // Max is lowered as a compare -> select.
+ // Match the pattern here.
+ mlir::Operation *reductionOp =
+ findReductionChain(loadVal, &reductionVal);
+ if (reductionOp == nullptr)
+ continue;
+
+ if (redIntrinsicProc ==
+ ReductionProcessor::IntrinsicProc::MAX ||
+ redIntrinsicProc ==
+ ReductionProcessor::IntrinsicProc::MIN) {
+ assert(mlir::isa<mlir::arith::SelectOp>(reductionOp) &&
+ "Selection Op not found in reduction intrinsic");
+ mlir::Operation *compareOp =
+ getCompareFromReductionOp(reductionOp, loadVal);
+ updateReduction(compareOp, firOpBuilder, loadVal,
+ reductionVal);
+ }
+ if (redIntrinsicProc ==
+ ReductionProcessor::IntrinsicProc::IOR ||
+ redIntrinsicProc ==
+ ReductionProcessor::IntrinsicProc::IEOR ||
+ redIntrinsicProc ==
+ ReductionProcessor::IntrinsicProc::IAND) {
- updateReduction(reductionOp, firOpBuilder, loadVal,
- reductionVal);
- }
+ updateReduction(reductionOp, firOpBuilder, loadVal,
+ reductionVal);
}
}
}
More information about the flang-commits
mailing list