[flang-commits] [flang] c2b3f16 - Revert "[Flang][OpenMP] NFC: Minor refactoring of Reduction lowering code" (#73139)
via flang-commits
flang-commits at lists.llvm.org
Wed Nov 22 07:47:28 PST 2023
Author: Kiran Chandramohan
Date: 2023-11-22T15:47:24Z
New Revision: c2b3f16fb595fa88bfd21b455785c59ac6a21ed4
URL: https://github.com/llvm/llvm-project/commit/c2b3f16fb595fa88bfd21b455785c59ac6a21ed4
DIFF: https://github.com/llvm/llvm-project/commit/c2b3f16fb595fa88bfd21b455785c59ac6a21ed4.diff
LOG: Revert "[Flang][OpenMP] NFC: Minor refactoring of Reduction lowering code" (#73139)
Reverts llvm/llvm-project#70790 to fix CI failure
(https://lab.llvm.org/buildbot/#/builders/268/builds/2884)
Added:
Modified:
flang/lib/Lower/OpenMP.cpp
Removed:
################################################################################
diff --git a/flang/lib/Lower/OpenMP.cpp b/flang/lib/Lower/OpenMP.cpp
index 6267231d7fbe253..f6a61ba3a528e32 100644
--- a/flang/lib/Lower/OpenMP.cpp
+++ b/flang/lib/Lower/OpenMP.cpp
@@ -685,441 +685,276 @@ static void checkMapType(mlir::Location location, mlir::Type type) {
TODO(location, "OMPD_target_data MapOperand BoxType");
}
-class ReductionProcessor {
-public:
- enum IntrinsicProc { MAX, MIN, IAND, IOR, IEOR };
- static IntrinsicProc
- getReductionType(const Fortran::parser::ProcedureDesignator &pd) {
- auto redType = llvm::StringSwitch<std::optional<IntrinsicProc>>(
- getRealName(pd).ToString())
- .Case("max", IntrinsicProc::MAX)
- .Case("min", IntrinsicProc::MIN)
- .Case("iand", IntrinsicProc::IAND)
- .Case("ior", IntrinsicProc::IOR)
- .Case("ieor", IntrinsicProc::IEOR)
- .Default(std::nullopt);
- assert(redType && "Invalid Reduction");
- return *redType;
- }
-
- static bool supportedIntrinsicProcReduction(
- const Fortran::parser::ProcedureDesignator &pd) {
- const auto *name{Fortran::parser::Unwrap<Fortran::parser::Name>(pd)};
- assert(name && "Invalid Reduction Intrinsic.");
- auto redType = llvm::StringSwitch<std::optional<IntrinsicProc>>(
- getRealName(name).ToString())
- .Case("max", IntrinsicProc::MAX)
- .Case("min", IntrinsicProc::MIN)
- .Case("iand", IntrinsicProc::IAND)
- .Case("ior", IntrinsicProc::IOR)
- .Case("ieor", IntrinsicProc::IEOR)
- .Default(std::nullopt);
- if (redType)
- return true;
- return false;
- }
-
- static const Fortran::semantics::SourceName
- getRealName(const Fortran::parser::Name *name) {
- return name->symbol->GetUltimate().name();
- }
-
- static const Fortran::semantics::SourceName
- getRealName(const Fortran::parser::ProcedureDesignator &pd) {
- const auto *name{Fortran::parser::Unwrap<Fortran::parser::Name>(pd)};
- assert(name && "Invalid Reduction Intrinsic.");
- return getRealName(name);
- }
-
- static std::string getReductionName(llvm::StringRef name, mlir::Type ty) {
- return (llvm::Twine(name) +
- (ty.isIntOrIndex() ? llvm::Twine("_i_") : llvm::Twine("_f_")) +
- llvm::Twine(ty.getIntOrFloatBitWidth()))
- .str();
- }
-
- static std::string getReductionName(
- Fortran::parser::DefinedOperator::IntrinsicOperator intrinsicOp,
- mlir::Type ty) {
- std::string reductionName;
-
- switch (intrinsicOp) {
- case Fortran::parser::DefinedOperator::IntrinsicOperator::Add:
- reductionName = "add_reduction";
- break;
- case Fortran::parser::DefinedOperator::IntrinsicOperator::Multiply:
- reductionName = "multiply_reduction";
- break;
- case Fortran::parser::DefinedOperator::IntrinsicOperator::AND:
- return "and_reduction";
- case Fortran::parser::DefinedOperator::IntrinsicOperator::EQV:
- return "eqv_reduction";
- case Fortran::parser::DefinedOperator::IntrinsicOperator::OR:
- return "or_reduction";
- case Fortran::parser::DefinedOperator::IntrinsicOperator::NEQV:
- return "neqv_reduction";
- default:
- reductionName = "other_reduction";
- break;
- }
-
- return getReductionName(reductionName, ty);
- }
+static std::string getReductionName(llvm::StringRef name, mlir::Type ty) {
+ return (llvm::Twine(name) +
+ (ty.isIntOrIndex() ? llvm::Twine("_i_") : llvm::Twine("_f_")) +
+ llvm::Twine(ty.getIntOrFloatBitWidth()))
+ .str();
+}
- /// This function returns the identity value of the operator \p
- /// reductionOpName. For example:
- /// 0 + x = x,
- /// 1 * x = x
- static int getOperationIdentity(
- Fortran::parser::DefinedOperator::IntrinsicOperator intrinsicOp,
- mlir::Location loc) {
- switch (intrinsicOp) {
- case Fortran::parser::DefinedOperator::IntrinsicOperator::Add:
- case Fortran::parser::DefinedOperator::IntrinsicOperator::OR:
- case Fortran::parser::DefinedOperator::IntrinsicOperator::NEQV:
- return 0;
- case Fortran::parser::DefinedOperator::IntrinsicOperator::Multiply:
- case Fortran::parser::DefinedOperator::IntrinsicOperator::AND:
- case Fortran::parser::DefinedOperator::IntrinsicOperator::EQV:
- return 1;
- default:
- TODO(loc, "Reduction of some intrinsic operators is not supported");
- }
- }
+static std::string getReductionName(
+ Fortran::parser::DefinedOperator::IntrinsicOperator intrinsicOp,
+ mlir::Type ty) {
+ std::string reductionName;
- static mlir::Value getIntrinsicProcInitValue(
- mlir::Location loc, mlir::Type type,
- const Fortran::parser::ProcedureDesignator &procDesignator,
- fir::FirOpBuilder &builder) {
- assert((fir::isa_integer(type) || fir::isa_real(type) ||
- type.isa<fir::LogicalType>()) &&
- "only integer, logical and real types are currently supported");
- switch (getReductionType(procDesignator)) {
- case IntrinsicProc::MAX: {
- if (auto ty = type.dyn_cast<mlir::FloatType>()) {
- const llvm::fltSemantics &sem = ty.getFloatSemantics();
- return builder.createRealConstant(
- loc, type, llvm::APFloat::getLargest(sem, /*Negative=*/true));
- }
- unsigned bits = type.getIntOrFloatBitWidth();
- int64_t minInt = llvm::APInt::getSignedMinValue(bits).getSExtValue();
- return builder.createIntegerConstant(loc, type, minInt);
- }
- case IntrinsicProc::MIN: {
- if (auto ty = type.dyn_cast<mlir::FloatType>()) {
- const llvm::fltSemantics &sem = ty.getFloatSemantics();
- return builder.createRealConstant(
- loc, type, llvm::APFloat::getLargest(sem, /*Negative=*/false));
- }
- unsigned bits = type.getIntOrFloatBitWidth();
- int64_t maxInt = llvm::APInt::getSignedMaxValue(bits).getSExtValue();
- return builder.createIntegerConstant(loc, type, maxInt);
- }
- case IntrinsicProc::IOR: {
- unsigned bits = type.getIntOrFloatBitWidth();
- int64_t zeroInt = llvm::APInt::getZero(bits).getSExtValue();
- return builder.createIntegerConstant(loc, type, zeroInt);
- }
- case IntrinsicProc::IEOR: {
- unsigned bits = type.getIntOrFloatBitWidth();
- int64_t zeroInt = llvm::APInt::getZero(bits).getSExtValue();
- return builder.createIntegerConstant(loc, type, zeroInt);
- }
- case IntrinsicProc::IAND: {
- unsigned bits = type.getIntOrFloatBitWidth();
- int64_t allOnInt = llvm::APInt::getAllOnes(bits).getSExtValue();
- return builder.createIntegerConstant(loc, type, allOnInt);
- }
- }
- llvm_unreachable("Unknown Reduction Intrinsic");
+ switch (intrinsicOp) {
+ case Fortran::parser::DefinedOperator::IntrinsicOperator::Add:
+ reductionName = "add_reduction";
+ break;
+ case Fortran::parser::DefinedOperator::IntrinsicOperator::Multiply:
+ reductionName = "multiply_reduction";
+ break;
+ case Fortran::parser::DefinedOperator::IntrinsicOperator::AND:
+ return "and_reduction";
+ case Fortran::parser::DefinedOperator::IntrinsicOperator::EQV:
+ return "eqv_reduction";
+ case Fortran::parser::DefinedOperator::IntrinsicOperator::OR:
+ return "or_reduction";
+ case Fortran::parser::DefinedOperator::IntrinsicOperator::NEQV:
+ return "neqv_reduction";
+ default:
+ reductionName = "other_reduction";
+ break;
}
- static mlir::Value getIntrinsicOpInitValue(
- mlir::Location loc, mlir::Type type,
- Fortran::parser::DefinedOperator::IntrinsicOperator intrinsicOp,
- fir::FirOpBuilder &builder) {
+ return getReductionName(reductionName, ty);
+}
+
+/// This function returns the identity value of the operator \p reductionOpName.
+/// For example:
+/// 0 + x = x,
+/// 1 * x = x
+static int getOperationIdentity(llvm::StringRef reductionOpName,
+ mlir::Location loc) {
+ if (reductionOpName.contains("add") || reductionOpName.contains("or") ||
+ reductionOpName.contains("neqv"))
+ return 0;
+ if (reductionOpName.contains("multiply") || reductionOpName.contains("and") ||
+ reductionOpName.contains("eqv"))
+ return 1;
+ TODO(loc, "Reduction of some intrinsic operators is not supported");
+}
+
+static mlir::Value getReductionInitValue(mlir::Location loc, mlir::Type type,
+ llvm::StringRef reductionOpName,
+ fir::FirOpBuilder &builder) {
+ assert((fir::isa_integer(type) || fir::isa_real(type) ||
+ type.isa<fir::LogicalType>()) &&
+ "only integer, logical and real types are currently supported");
+ if (reductionOpName.contains("max")) {
+ if (auto ty = type.dyn_cast<mlir::FloatType>()) {
+ const llvm::fltSemantics &sem = ty.getFloatSemantics();
+ return builder.createRealConstant(
+ loc, type, llvm::APFloat::getLargest(sem, /*Negative=*/true));
+ }
+ unsigned bits = type.getIntOrFloatBitWidth();
+ int64_t minInt = llvm::APInt::getSignedMinValue(bits).getSExtValue();
+ return builder.createIntegerConstant(loc, type, minInt);
+ } else if (reductionOpName.contains("min")) {
+ if (auto ty = type.dyn_cast<mlir::FloatType>()) {
+ const llvm::fltSemantics &sem = ty.getFloatSemantics();
+ return builder.createRealConstant(
+ loc, type, llvm::APFloat::getLargest(sem, /*Negative=*/false));
+ }
+ unsigned bits = type.getIntOrFloatBitWidth();
+ int64_t maxInt = llvm::APInt::getSignedMaxValue(bits).getSExtValue();
+ return builder.createIntegerConstant(loc, type, maxInt);
+ } else if (reductionOpName.contains("ior")) {
+ unsigned bits = type.getIntOrFloatBitWidth();
+ int64_t zeroInt = llvm::APInt::getZero(bits).getSExtValue();
+ return builder.createIntegerConstant(loc, type, zeroInt);
+ } else if (reductionOpName.contains("ieor")) {
+ unsigned bits = type.getIntOrFloatBitWidth();
+ int64_t zeroInt = llvm::APInt::getZero(bits).getSExtValue();
+ return builder.createIntegerConstant(loc, type, zeroInt);
+ } else if (reductionOpName.contains("iand")) {
+ unsigned bits = type.getIntOrFloatBitWidth();
+ int64_t allOnInt = llvm::APInt::getAllOnes(bits).getSExtValue();
+ return builder.createIntegerConstant(loc, type, allOnInt);
+ } else {
if (type.isa<mlir::FloatType>())
return builder.create<mlir::arith::ConstantOp>(
loc, type,
- builder.getFloatAttr(type,
- (double)getOperationIdentity(intrinsicOp, loc)));
+ builder.getFloatAttr(
+ type, (double)getOperationIdentity(reductionOpName, loc)));
if (type.isa<fir::LogicalType>()) {
mlir::Value intConst = builder.create<mlir::arith::ConstantOp>(
loc, builder.getI1Type(),
builder.getIntegerAttr(builder.getI1Type(),
- getOperationIdentity(intrinsicOp, loc)));
+ getOperationIdentity(reductionOpName, loc)));
return builder.createConvert(loc, type, intConst);
}
return builder.create<mlir::arith::ConstantOp>(
loc, type,
- builder.getIntegerAttr(type, getOperationIdentity(intrinsicOp, loc)));
- }
-
- template <typename FloatOp, typename IntegerOp>
- static mlir::Value getReductionOperation(fir::FirOpBuilder &builder,
- mlir::Type type, mlir::Location loc,
- mlir::Value op1, mlir::Value op2) {
- assert(type.isIntOrIndexOrFloat() &&
- "only integer and float types are currently supported");
- if (type.isIntOrIndex())
- return builder.create<IntegerOp>(loc, op1, op2);
- return builder.create<FloatOp>(loc, op1, op2);
- }
-
- /// Creates an OpenMP reduction declaration and inserts it into the provided
- /// symbol table. The declaration has a constant initializer with the neutral
- /// value `initValue`, and the reduction combiner carried over from `reduce`.
- /// TODO: Generalize this for non-integer types, add atomic region.
- static mlir::omp::ReductionDeclareOp createReductionDecl(
- fir::FirOpBuilder &builder, llvm::StringRef reductionOpName,
- const Fortran::parser::ProcedureDesignator &procDesignator,
- mlir::Type type, mlir::Location loc) {
- mlir::OpBuilder::InsertionGuard guard(builder);
- mlir::ModuleOp module = builder.getModule();
-
- auto decl =
- module.lookupSymbol<mlir::omp::ReductionDeclareOp>(reductionOpName);
- if (decl)
- return decl;
+ builder.getIntegerAttr(type,
+ getOperationIdentity(reductionOpName, loc)));
+ }
+}
+
+template <typename FloatOp, typename IntegerOp>
+static mlir::Value getReductionOperation(fir::FirOpBuilder &builder,
+ mlir::Type type, mlir::Location loc,
+ mlir::Value op1, mlir::Value op2) {
+ assert(type.isIntOrIndexOrFloat() &&
+ "only integer and float types are currently supported");
+ if (type.isIntOrIndex())
+ return builder.create<IntegerOp>(loc, op1, op2);
+ return builder.create<FloatOp>(loc, op1, op2);
+}
+
+static mlir::omp::ReductionDeclareOp
+createMinimalReductionDecl(fir::FirOpBuilder &builder,
+ llvm::StringRef reductionOpName, mlir::Type type,
+ mlir::Location loc) {
+ mlir::ModuleOp module = builder.getModule();
+ mlir::OpBuilder modBuilder(module.getBodyRegion());
+
+ mlir::omp::ReductionDeclareOp decl =
+ modBuilder.create<mlir::omp::ReductionDeclareOp>(loc, reductionOpName,
+ type);
+ builder.createBlock(&decl.getInitializerRegion(),
+ decl.getInitializerRegion().end(), {type}, {loc});
+ builder.setInsertionPointToEnd(&decl.getInitializerRegion().back());
+ mlir::Value init = getReductionInitValue(loc, type, reductionOpName, builder);
+ builder.create<mlir::omp::YieldOp>(loc, init);
+
+ builder.createBlock(&decl.getReductionRegion(),
+ decl.getReductionRegion().end(), {type, type},
+ {loc, loc});
+
+ return decl;
+}
+
+/// Creates an OpenMP reduction declaration and inserts it into the provided
+/// symbol table. The declaration has a constant initializer with the neutral
+/// value `initValue`, and the reduction combiner carried over from `reduce`.
+/// TODO: Generalize this for non-integer types, add atomic region.
+static mlir::omp::ReductionDeclareOp
+createReductionDecl(fir::FirOpBuilder &builder, llvm::StringRef reductionOpName,
+ const Fortran::parser::ProcedureDesignator &procDesignator,
+ mlir::Type type, mlir::Location loc) {
+ mlir::OpBuilder::InsertionGuard guard(builder);
+ mlir::ModuleOp module = builder.getModule();
+
+ auto decl =
+ module.lookupSymbol<mlir::omp::ReductionDeclareOp>(reductionOpName);
+ if (decl)
+ return decl;
- mlir::OpBuilder modBuilder(module.getBodyRegion());
+ decl = createMinimalReductionDecl(builder, reductionOpName, type, loc);
+ builder.setInsertionPointToEnd(&decl.getReductionRegion().back());
+ mlir::Value op1 = decl.getReductionRegion().front().getArgument(0);
+ mlir::Value op2 = decl.getReductionRegion().front().getArgument(1);
- decl = modBuilder.create<mlir::omp::ReductionDeclareOp>(
- loc, reductionOpName, type);
- builder.createBlock(&decl.getInitializerRegion(),
- decl.getInitializerRegion().end(), {type}, {loc});
- builder.setInsertionPointToEnd(&decl.getInitializerRegion().back());
- mlir::Value init =
- getIntrinsicProcInitValue(loc, type, procDesignator, builder);
- builder.create<mlir::omp::YieldOp>(loc, init);
-
- builder.createBlock(&decl.getReductionRegion(),
- decl.getReductionRegion().end(), {type, type},
- {loc, loc});
-
- builder.setInsertionPointToEnd(&decl.getReductionRegion().back());
- mlir::Value op1 = decl.getReductionRegion().front().getArgument(0);
- mlir::Value op2 = decl.getReductionRegion().front().getArgument(1);
-
- mlir::Value reductionOp;
- switch (getReductionType(procDesignator)) {
- case IntrinsicProc::MAX:
+ mlir::Value reductionOp;
+ if (const auto *name{
+ Fortran::parser::Unwrap<Fortran::parser::Name>(procDesignator)}) {
+ if (name->source == "max") {
reductionOp =
getReductionOperation<mlir::arith::MaximumFOp, mlir::arith::MaxSIOp>(
builder, type, loc, op1, op2);
- break;
- case IntrinsicProc::MIN:
+ } else if (name->source == "min") {
reductionOp =
getReductionOperation<mlir::arith::MinimumFOp, mlir::arith::MinSIOp>(
builder, type, loc, op1, op2);
- break;
- case IntrinsicProc::IOR:
+ } else if (name->source == "ior") {
assert((type.isIntOrIndex()) && "only integer is expected");
reductionOp = builder.create<mlir::arith::OrIOp>(loc, op1, op2);
- break;
- case IntrinsicProc::IEOR:
+ } else if (name->source == "ieor") {
assert((type.isIntOrIndex()) && "only integer is expected");
reductionOp = builder.create<mlir::arith::XOrIOp>(loc, op1, op2);
- break;
- case IntrinsicProc::IAND:
+ } else if (name->source == "iand") {
assert((type.isIntOrIndex()) && "only integer is expected");
reductionOp = builder.create<mlir::arith::AndIOp>(loc, op1, op2);
- break;
- default:
- llvm_unreachable(
- "Reduction of some intrinsic operators is not supported");
+ } else {
+ TODO(loc, "Reduction of some intrinsic operators is not supported");
}
-
- builder.create<mlir::omp::YieldOp>(loc, reductionOp);
- return decl;
}
- /// Creates an OpenMP reduction declaration and inserts it into the provided
- /// symbol table. The declaration has a constant initializer with the neutral
- /// value `initValue`, and the reduction combiner carried over from `reduce`.
- /// TODO: Generalize this for non-integer types, add atomic region.
- static mlir::omp::ReductionDeclareOp createReductionDecl(
- fir::FirOpBuilder &builder, llvm::StringRef reductionOpName,
- Fortran::parser::DefinedOperator::IntrinsicOperator intrinsicOp,
- mlir::Type type, mlir::Location loc) {
- mlir::OpBuilder::InsertionGuard guard(builder);
- mlir::ModuleOp module = builder.getModule();
-
- auto decl =
- module.lookupSymbol<mlir::omp::ReductionDeclareOp>(reductionOpName);
- if (decl)
- return decl;
-
- mlir::OpBuilder modBuilder(module.getBodyRegion());
-
- decl = modBuilder.create<mlir::omp::ReductionDeclareOp>(
- loc, reductionOpName, type);
- builder.createBlock(&decl.getInitializerRegion(),
- decl.getInitializerRegion().end(), {type}, {loc});
- builder.setInsertionPointToEnd(&decl.getInitializerRegion().back());
- mlir::Value init = getIntrinsicOpInitValue(loc, type, intrinsicOp, builder);
- builder.create<mlir::omp::YieldOp>(loc, init);
+ builder.create<mlir::omp::YieldOp>(loc, reductionOp);
+ return decl;
+}
- builder.createBlock(&decl.getReductionRegion(),
- decl.getReductionRegion().end(), {type, type},
- {loc, loc});
+/// Creates an OpenMP reduction declaration and inserts it into the provided
+/// symbol table. The declaration has a constant initializer with the neutral
+/// value `initValue`, and the reduction combiner carried over from `reduce`.
+/// TODO: Generalize this for non-integer types, add atomic region.
+static mlir::omp::ReductionDeclareOp createReductionDecl(
+ fir::FirOpBuilder &builder, llvm::StringRef reductionOpName,
+ Fortran::parser::DefinedOperator::IntrinsicOperator intrinsicOp,
+ mlir::Type type, mlir::Location loc) {
+ mlir::OpBuilder::InsertionGuard guard(builder);
+ mlir::ModuleOp module = builder.getModule();
- builder.setInsertionPointToEnd(&decl.getReductionRegion().back());
- mlir::Value op1 = decl.getReductionRegion().front().getArgument(0);
- mlir::Value op2 = decl.getReductionRegion().front().getArgument(1);
+ auto decl =
+ module.lookupSymbol<mlir::omp::ReductionDeclareOp>(reductionOpName);
+ if (decl)
+ return decl;
- mlir::Value reductionOp;
- switch (intrinsicOp) {
- case Fortran::parser::DefinedOperator::IntrinsicOperator::Add:
- reductionOp =
- getReductionOperation<mlir::arith::AddFOp, mlir::arith::AddIOp>(
- builder, type, loc, op1, op2);
- break;
- case Fortran::parser::DefinedOperator::IntrinsicOperator::Multiply:
- reductionOp =
- getReductionOperation<mlir::arith::MulFOp, mlir::arith::MulIOp>(
- builder, type, loc, op1, op2);
- break;
- case Fortran::parser::DefinedOperator::IntrinsicOperator::AND: {
- mlir::Value op1I1 = builder.createConvert(loc, builder.getI1Type(), op1);
- mlir::Value op2I1 = builder.createConvert(loc, builder.getI1Type(), op2);
+ decl = createMinimalReductionDecl(builder, reductionOpName, type, loc);
+ builder.setInsertionPointToEnd(&decl.getReductionRegion().back());
+ mlir::Value op1 = decl.getReductionRegion().front().getArgument(0);
+ mlir::Value op2 = decl.getReductionRegion().front().getArgument(1);
- mlir::Value andiOp =
- builder.create<mlir::arith::AndIOp>(loc, op1I1, op2I1);
+ mlir::Value reductionOp;
+ switch (intrinsicOp) {
+ case Fortran::parser::DefinedOperator::IntrinsicOperator::Add:
+ reductionOp =
+ getReductionOperation<mlir::arith::AddFOp, mlir::arith::AddIOp>(
+ builder, type, loc, op1, op2);
+ break;
+ case Fortran::parser::DefinedOperator::IntrinsicOperator::Multiply:
+ reductionOp =
+ getReductionOperation<mlir::arith::MulFOp, mlir::arith::MulIOp>(
+ builder, type, loc, op1, op2);
+ break;
+ case Fortran::parser::DefinedOperator::IntrinsicOperator::AND: {
+ mlir::Value op1I1 = builder.createConvert(loc, builder.getI1Type(), op1);
+ mlir::Value op2I1 = builder.createConvert(loc, builder.getI1Type(), op2);
- reductionOp = builder.createConvert(loc, type, andiOp);
- break;
- }
- case Fortran::parser::DefinedOperator::IntrinsicOperator::OR: {
- mlir::Value op1I1 = builder.createConvert(loc, builder.getI1Type(), op1);
- mlir::Value op2I1 = builder.createConvert(loc, builder.getI1Type(), op2);
+ mlir::Value andiOp = builder.create<mlir::arith::AndIOp>(loc, op1I1, op2I1);
- mlir::Value oriOp = builder.create<mlir::arith::OrIOp>(loc, op1I1, op2I1);
+ reductionOp = builder.createConvert(loc, type, andiOp);
+ break;
+ }
+ case Fortran::parser::DefinedOperator::IntrinsicOperator::OR: {
+ mlir::Value op1I1 = builder.createConvert(loc, builder.getI1Type(), op1);
+ mlir::Value op2I1 = builder.createConvert(loc, builder.getI1Type(), op2);
- reductionOp = builder.createConvert(loc, type, oriOp);
- break;
- }
- case Fortran::parser::DefinedOperator::IntrinsicOperator::EQV: {
- mlir::Value op1I1 = builder.createConvert(loc, builder.getI1Type(), op1);
- mlir::Value op2I1 = builder.createConvert(loc, builder.getI1Type(), op2);
+ mlir::Value oriOp = builder.create<mlir::arith::OrIOp>(loc, op1I1, op2I1);
- mlir::Value cmpiOp = builder.create<mlir::arith::CmpIOp>(
- loc, mlir::arith::CmpIPredicate::eq, op1I1, op2I1);
+ reductionOp = builder.createConvert(loc, type, oriOp);
+ break;
+ }
+ case Fortran::parser::DefinedOperator::IntrinsicOperator::EQV: {
+ mlir::Value op1I1 = builder.createConvert(loc, builder.getI1Type(), op1);
+ mlir::Value op2I1 = builder.createConvert(loc, builder.getI1Type(), op2);
- reductionOp = builder.createConvert(loc, type, cmpiOp);
- break;
- }
- case Fortran::parser::DefinedOperator::IntrinsicOperator::NEQV: {
- mlir::Value op1I1 = builder.createConvert(loc, builder.getI1Type(), op1);
- mlir::Value op2I1 = builder.createConvert(loc, builder.getI1Type(), op2);
+ mlir::Value cmpiOp = builder.create<mlir::arith::CmpIOp>(
+ loc, mlir::arith::CmpIPredicate::eq, op1I1, op2I1);
- mlir::Value cmpiOp = builder.create<mlir::arith::CmpIOp>(
- loc, mlir::arith::CmpIPredicate::ne, op1I1, op2I1);
+ reductionOp = builder.createConvert(loc, type, cmpiOp);
+ break;
+ }
+ case Fortran::parser::DefinedOperator::IntrinsicOperator::NEQV: {
+ mlir::Value op1I1 = builder.createConvert(loc, builder.getI1Type(), op1);
+ mlir::Value op2I1 = builder.createConvert(loc, builder.getI1Type(), op2);
- reductionOp = builder.createConvert(loc, type, cmpiOp);
- break;
- }
- default:
- TODO(loc, "Reduction of some intrinsic operators is not supported");
- }
+ mlir::Value cmpiOp = builder.create<mlir::arith::CmpIOp>(
+ loc, mlir::arith::CmpIPredicate::ne, op1I1, op2I1);
- builder.create<mlir::omp::YieldOp>(loc, reductionOp);
- return decl;
+ reductionOp = builder.createConvert(loc, type, cmpiOp);
+ break;
}
-
- /// Creates a reduction declaration and associates it with an OpenMP block
- /// directive.
- static void addReductionDecl(
- mlir::Location currentLocation,
- Fortran::lower::AbstractConverter &converter,
- const Fortran::parser::OmpReductionClause &reduction,
- llvm::SmallVectorImpl<mlir::Value> &reductionVars,
- llvm::SmallVectorImpl<mlir::Attribute> &reductionDeclSymbols) {
- fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
- mlir::omp::ReductionDeclareOp decl;
- const auto &redOperator{
- std::get<Fortran::parser::OmpReductionOperator>(reduction.t)};
- const auto &objectList{
- std::get<Fortran::parser::OmpObjectList>(reduction.t)};
- if (const auto &redDefinedOp =
- std::get_if<Fortran::parser::DefinedOperator>(&redOperator.u)) {
- const auto &intrinsicOp{
- std::get<Fortran::parser::DefinedOperator::IntrinsicOperator>(
- redDefinedOp->u)};
- switch (intrinsicOp) {
- case Fortran::parser::DefinedOperator::IntrinsicOperator::Add:
- case Fortran::parser::DefinedOperator::IntrinsicOperator::Multiply:
- case Fortran::parser::DefinedOperator::IntrinsicOperator::AND:
- case Fortran::parser::DefinedOperator::IntrinsicOperator::EQV:
- case Fortran::parser::DefinedOperator::IntrinsicOperator::OR:
- case Fortran::parser::DefinedOperator::IntrinsicOperator::NEQV:
- break;
-
- default:
- TODO(currentLocation,
- "Reduction of some intrinsic operators is not supported");
- break;
- }
- for (const Fortran::parser::OmpObject &ompObject : objectList.v) {
- if (const auto *name{
- Fortran::parser::Unwrap<Fortran::parser::Name>(ompObject)}) {
- if (const Fortran::semantics::Symbol * symbol{name->symbol}) {
- mlir::Value symVal = converter.getSymbolAddress(*symbol);
- if (auto declOp = symVal.getDefiningOp<hlfir::DeclareOp>())
- symVal = declOp.getBase();
- mlir::Type redType =
- symVal.getType().cast<fir::ReferenceType>().getEleTy();
- reductionVars.push_back(symVal);
- if (redType.isa<fir::LogicalType>())
- decl = createReductionDecl(
- firOpBuilder,
- getReductionName(intrinsicOp, firOpBuilder.getI1Type()),
- intrinsicOp, redType, currentLocation);
- else if (redType.isIntOrIndexOrFloat()) {
- decl = createReductionDecl(firOpBuilder,
- getReductionName(intrinsicOp, redType),
- intrinsicOp, redType, currentLocation);
- } else {
- TODO(currentLocation, "Reduction of some types is not supported");
- }
- reductionDeclSymbols.push_back(mlir::SymbolRefAttr::get(
- firOpBuilder.getContext(), decl.getSymName()));
- }
- }
- }
- } else if (const auto *reductionIntrinsic =
- std::get_if<Fortran::parser::ProcedureDesignator>(
- &redOperator.u)) {
- if (ReductionProcessor::supportedIntrinsicProcReduction(
- *reductionIntrinsic)) {
- for (const Fortran::parser::OmpObject &ompObject : objectList.v) {
- if (const auto *name{
- Fortran::parser::Unwrap<Fortran::parser::Name>(ompObject)}) {
- if (const Fortran::semantics::Symbol * symbol{name->symbol}) {
- mlir::Value symVal = converter.getSymbolAddress(*symbol);
- if (auto declOp = symVal.getDefiningOp<hlfir::DeclareOp>())
- symVal = declOp.getBase();
- mlir::Type redType =
- symVal.getType().cast<fir::ReferenceType>().getEleTy();
- reductionVars.push_back(symVal);
- assert(redType.isIntOrIndexOrFloat() &&
- "Unsupported reduction type");
- decl = createReductionDecl(
- firOpBuilder,
- getReductionName(getRealName(*reductionIntrinsic).ToString(),
- redType),
- *reductionIntrinsic, redType, currentLocation);
- reductionDeclSymbols.push_back(mlir::SymbolRefAttr::get(
- firOpBuilder.getContext(), decl.getSymName()));
- }
- }
- }
- }
- }
+ default:
+ TODO(loc, "Reduction of some intrinsic operators is not supported");
}
-};
+
+ builder.create<mlir::omp::YieldOp>(loc, reductionOp);
+ return decl;
+}
static mlir::omp::ScheduleModifier
translateScheduleModifier(const Fortran::parser::OmpScheduleModifierType &m) {
@@ -1302,6 +1137,101 @@ static mlir::Value getIfClauseOperand(
ifVal);
}
+/// Creates a reduction declaration and associates it with an OpenMP block
+/// directive.
+static void
+addReductionDecl(mlir::Location currentLocation,
+ Fortran::lower::AbstractConverter &converter,
+ const Fortran::parser::OmpReductionClause &reduction,
+ llvm::SmallVectorImpl<mlir::Value> &reductionVars,
+ llvm::SmallVectorImpl<mlir::Attribute> &reductionDeclSymbols) {
+ fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
+ mlir::omp::ReductionDeclareOp decl;
+ const auto &redOperator{
+ std::get<Fortran::parser::OmpReductionOperator>(reduction.t)};
+ const auto &objectList{std::get<Fortran::parser::OmpObjectList>(reduction.t)};
+ if (const auto &redDefinedOp =
+ std::get_if<Fortran::parser::DefinedOperator>(&redOperator.u)) {
+ const auto &intrinsicOp{
+ std::get<Fortran::parser::DefinedOperator::IntrinsicOperator>(
+ redDefinedOp->u)};
+ switch (intrinsicOp) {
+ case Fortran::parser::DefinedOperator::IntrinsicOperator::Add:
+ case Fortran::parser::DefinedOperator::IntrinsicOperator::Multiply:
+ case Fortran::parser::DefinedOperator::IntrinsicOperator::AND:
+ case Fortran::parser::DefinedOperator::IntrinsicOperator::EQV:
+ case Fortran::parser::DefinedOperator::IntrinsicOperator::OR:
+ case Fortran::parser::DefinedOperator::IntrinsicOperator::NEQV:
+ break;
+
+ default:
+ TODO(currentLocation,
+ "Reduction of some intrinsic operators is not supported");
+ break;
+ }
+ for (const Fortran::parser::OmpObject &ompObject : objectList.v) {
+ if (const auto *name{
+ Fortran::parser::Unwrap<Fortran::parser::Name>(ompObject)}) {
+ if (const Fortran::semantics::Symbol * symbol{name->symbol}) {
+ mlir::Value symVal = converter.getSymbolAddress(*symbol);
+ if (auto declOp = symVal.getDefiningOp<hlfir::DeclareOp>())
+ symVal = declOp.getBase();
+ mlir::Type redType =
+ symVal.getType().cast<fir::ReferenceType>().getEleTy();
+ reductionVars.push_back(symVal);
+ if (redType.isa<fir::LogicalType>())
+ decl = createReductionDecl(
+ firOpBuilder,
+ getReductionName(intrinsicOp, firOpBuilder.getI1Type()),
+ intrinsicOp, redType, currentLocation);
+ else if (redType.isIntOrIndexOrFloat()) {
+ decl = createReductionDecl(firOpBuilder,
+ getReductionName(intrinsicOp, redType),
+ intrinsicOp, redType, currentLocation);
+ } else {
+ TODO(currentLocation, "Reduction of some types is not supported");
+ }
+ reductionDeclSymbols.push_back(mlir::SymbolRefAttr::get(
+ firOpBuilder.getContext(), decl.getSymName()));
+ }
+ }
+ }
+ } else if (const auto *reductionIntrinsic =
+ std::get_if<Fortran::parser::ProcedureDesignator>(
+ &redOperator.u)) {
+ if (const auto *name{Fortran::parser::Unwrap<Fortran::parser::Name>(
+ reductionIntrinsic)}) {
+ if ((name->source != "max") && (name->source != "min") &&
+ (name->source != "ior") && (name->source != "ieor") &&
+ (name->source != "iand")) {
+ TODO(currentLocation,
+ "Reduction of intrinsic procedures is not supported");
+ }
+ std::string intrinsicOp = name->ToString();
+ for (const Fortran::parser::OmpObject &ompObject : objectList.v) {
+ if (const auto *name{
+ Fortran::parser::Unwrap<Fortran::parser::Name>(ompObject)}) {
+ if (const Fortran::semantics::Symbol * symbol{name->symbol}) {
+ mlir::Value symVal = converter.getSymbolAddress(*symbol);
+ if (auto declOp = symVal.getDefiningOp<hlfir::DeclareOp>())
+ symVal = declOp.getBase();
+ mlir::Type redType =
+ symVal.getType().cast<fir::ReferenceType>().getEleTy();
+ reductionVars.push_back(symVal);
+ assert(redType.isIntOrIndexOrFloat() &&
+ "Unsupported reduction type");
+ decl = createReductionDecl(
+ firOpBuilder, getReductionName(intrinsicOp, redType),
+ *reductionIntrinsic, redType, currentLocation);
+ reductionDeclSymbols.push_back(mlir::SymbolRefAttr::get(
+ firOpBuilder.getContext(), decl.getSymName()));
+ }
+ }
+ }
+ }
+ }
+}
+
static void
addUseDeviceClause(Fortran::lower::AbstractConverter &converter,
const Fortran::parser::OmpObjectList &useDeviceClause,
@@ -1898,9 +1828,8 @@ bool ClauseProcessor::processReduction(
return findRepeatableClause<ClauseTy::Reduction>(
[&](const ClauseTy::Reduction *reductionClause,
const Fortran::parser::CharBlock &) {
- ReductionProcessor rp;
- rp.addReductionDecl(currentLocation, converter, reductionClause->v,
- reductionVars, reductionDeclSymbols);
+ addReductionDecl(currentLocation, converter, reductionClause->v,
+ reductionVars, reductionDeclSymbols);
});
}
@@ -3736,50 +3665,48 @@ void Fortran::lower::genOpenMPReduction(
} else if (const auto *reductionIntrinsic =
std::get_if<Fortran::parser::ProcedureDesignator>(
&redOperator.u)) {
- if (!ReductionProcessor::supportedIntrinsicProcReduction(
- *reductionIntrinsic))
- continue;
- ReductionProcessor::IntrinsicProc redIntrinsicProc =
- ReductionProcessor::getReductionType(*reductionIntrinsic);
- for (const Fortran::parser::OmpObject &ompObject : objectList.v) {
- if (const auto *name{
- Fortran::parser::Unwrap<Fortran::parser::Name>(ompObject)}) {
- if (const Fortran::semantics::Symbol * symbol{name->symbol}) {
- mlir::Value reductionVal = converter.getSymbolAddress(*symbol);
- if (auto declOp = reductionVal.getDefiningOp<hlfir::DeclareOp>())
- reductionVal = declOp.getBase();
- for (const mlir::OpOperand &reductionValUse :
- reductionVal.getUses()) {
- if (auto loadOp = mlir::dyn_cast<fir::LoadOp>(
- reductionValUse.getOwner())) {
- mlir::Value loadVal = loadOp.getRes();
- // Max is lowered as a compare -> select.
- // Match the pattern here.
- mlir::Operation *reductionOp =
- findReductionChain(loadVal, &reductionVal);
- if (reductionOp == nullptr)
- continue;
-
- if (redIntrinsicProc ==
- ReductionProcessor::IntrinsicProc::MAX ||
- redIntrinsicProc ==
- ReductionProcessor::IntrinsicProc::MIN) {
- assert(mlir::isa<mlir::arith::SelectOp>(reductionOp) &&
- "Selection Op not found in reduction intrinsic");
- mlir::Operation *compareOp =
- getCompareFromReductionOp(reductionOp, loadVal);
- updateReduction(compareOp, firOpBuilder, loadVal,
- reductionVal);
- }
- if (redIntrinsicProc ==
- ReductionProcessor::IntrinsicProc::IOR ||
- redIntrinsicProc ==
- ReductionProcessor::IntrinsicProc::IEOR ||
- redIntrinsicProc ==
- ReductionProcessor::IntrinsicProc::IAND) {
+ if (const auto *name{Fortran::parser::Unwrap<Fortran::parser::Name>(
+ reductionIntrinsic)}) {
+ std::string redName = name->ToString();
+ if ((name->source != "max") && (name->source != "min") &&
+ (name->source != "ior") && (name->source != "ieor") &&
+ (name->source != "iand")) {
+ continue;
+ }
+ for (const Fortran::parser::OmpObject &ompObject : objectList.v) {
+ if (const auto *name{Fortran::parser::Unwrap<Fortran::parser::Name>(
+ ompObject)}) {
+ if (const Fortran::semantics::Symbol * symbol{name->symbol}) {
+ mlir::Value reductionVal = converter.getSymbolAddress(*symbol);
+ if (auto declOp =
+ reductionVal.getDefiningOp<hlfir::DeclareOp>())
+ reductionVal = declOp.getBase();
+ for (const mlir::OpOperand &reductionValUse :
+ reductionVal.getUses()) {
+ if (auto loadOp = mlir::dyn_cast<fir::LoadOp>(
+ reductionValUse.getOwner())) {
+ mlir::Value loadVal = loadOp.getRes();
+ // Max is lowered as a compare -> select.
+ // Match the pattern here.
+ mlir::Operation *reductionOp =
+ findReductionChain(loadVal, &reductionVal);
+ if (reductionOp == nullptr)
+ continue;
+
+ if (redName == "max" || redName == "min") {
+ assert(mlir::isa<mlir::arith::SelectOp>(reductionOp) &&
+ "Selection Op not found in reduction intrinsic");
+ mlir::Operation *compareOp =
+ getCompareFromReductionOp(reductionOp, loadVal);
+ updateReduction(compareOp, firOpBuilder, loadVal,
+ reductionVal);
+ }
+ if (redName == "ior" || redName == "ieor" ||
+ redName == "iand") {
- updateReduction(reductionOp, firOpBuilder, loadVal,
- reductionVal);
+ updateReduction(reductionOp, firOpBuilder, loadVal,
+ reductionVal);
+ }
}
}
}
More information about the flang-commits mailing list