[flang-commits] [flang] [flang][OpenMP] Make several functions local to OpenMP.cpp, NFC (PR #86726)
Krzysztof Parzyszek via flang-commits
flang-commits at lists.llvm.org
Tue Mar 26 13:04:58 PDT 2024
https://github.com/kparzysz created https://github.com/llvm/llvm-project/pull/86726
There were several functions, mostly reduction-related, that were only called from OpenMP.cpp. Remove them from OpenMP.h, and make them local in OpenMP.cpp:
- genOpenMPReduction
- findReductionChain
- getConvertFromReductionOp
- updateReduction
- removeStoreOp
Also, move the function bodies out of the "public" section.
>From edd931dabfdf283db16f2fb2510c51f56127d9bb Mon Sep 17 00:00:00 2001
From: Krzysztof Parzyszek <Krzysztof.Parzyszek at amd.com>
Date: Tue, 26 Mar 2024 13:39:49 -0500
Subject: [PATCH] [flang][OpenMP] Make several functions local to OpenMP.cpp,
NFC
There were several functions, mostly reduction-related, that were only called
from OpenMP.cpp. Remove them from OpenMP.h, and make them local in OpenMP.cpp:
- genOpenMPReduction
- findReductionChain
- getConvertFromReductionOp
- updateReduction
- removeStoreOp
Also, move the function bodies out of the "public" section.
---
flang/include/flang/Lower/OpenMP.h | 12 -
flang/lib/Lower/OpenMP/OpenMP.cpp | 417 ++++++++++++++---------------
2 files changed, 207 insertions(+), 222 deletions(-)
diff --git a/flang/include/flang/Lower/OpenMP.h b/flang/include/flang/Lower/OpenMP.h
index 3b22a652d1fc1e..6e150ef4e8e82f 100644
--- a/flang/include/flang/Lower/OpenMP.h
+++ b/flang/include/flang/Lower/OpenMP.h
@@ -19,7 +19,6 @@
#include <utility>
namespace mlir {
-class Value;
class Operation;
class Location;
namespace omp {
@@ -30,7 +29,6 @@ enum class DeclareTargetCaptureClause : uint32_t;
namespace fir {
class FirOpBuilder;
-class ConvertOp;
} // namespace fir
namespace Fortran {
@@ -84,16 +82,6 @@ void genOpenMPSymbolProperties(AbstractConverter &converter,
int64_t getCollapseValue(const Fortran::parser::OmpClauseList &clauseList);
void genThreadprivateOp(AbstractConverter &, const pft::Variable &);
void genDeclareTargetIntGlobal(AbstractConverter &, const pft::Variable &);
-void genOpenMPReduction(AbstractConverter &,
- Fortran::semantics::SemanticsContext &,
- const Fortran::parser::OmpClauseList &clauseList);
-
-mlir::Operation *findReductionChain(mlir::Value, mlir::Value * = nullptr);
-fir::ConvertOp getConvertFromReductionOp(mlir::Operation *, mlir::Value);
-void updateReduction(mlir::Operation *, fir::FirOpBuilder &, mlir::Value,
- mlir::Value, fir::ConvertOp * = nullptr);
-void removeStoreOp(mlir::Operation *, mlir::Value);
-
bool isOpenMPTargetConstruct(const parser::OpenMPConstruct &);
bool isOpenMPDeviceDeclareTarget(Fortran::lower::AbstractConverter &,
Fortran::semantics::SemanticsContext &,
diff --git a/flang/lib/Lower/OpenMP/OpenMP.cpp b/flang/lib/Lower/OpenMP/OpenMP.cpp
index 0cf2a8f97040a8..0a728b65afbf06 100644
--- a/flang/lib/Lower/OpenMP/OpenMP.cpp
+++ b/flang/lib/Lower/OpenMP/OpenMP.cpp
@@ -237,6 +237,213 @@ createAndSetPrivatizedLoopVar(Fortran::lower::AbstractConverter &converter,
return storeOp;
}
+// Find the operation that combines a loaded reduction value with the other
+// reduction operand (X = X op Y). `loadVal` is the result of a fir.load of the
+// reduction variable.
+//
+// Users of `loadVal` are visited in use-list order:
+//  - If a user is a fir.convert that has users, the first user of the
+//    convert's result is returned (nothing is erased).
+//  - Otherwise, if `reductionVal` is given and one of the user's results is
+//    written to `*reductionVal` by a fir.store or hlfir.assign, that
+//    store/assign is erased and the user is returned.
+// Returns nullptr when no such operation is found.
+static mlir::Operation *
+findReductionChain(mlir::Value loadVal, mlir::Value *reductionVal = nullptr) {
+  for (mlir::OpOperand &loadOperand : loadVal.getUses()) {
+    if (mlir::Operation *reductionOp = loadOperand.getOwner()) {
+      if (auto convertOp = mlir::dyn_cast<fir::ConvertOp>(reductionOp)) {
+        for (mlir::OpOperand &convertOperand : convertOp.getRes().getUses()) {
+          if (mlir::Operation *convertUser = convertOperand.getOwner())
+            return convertUser;
+        }
+      }
+      // The store/assign match below compares against *reductionVal. Callers
+      // may omit it (it defaults to null); there is nothing to match then, and
+      // dereferencing it would be a null-pointer dereference.
+      if (!reductionVal)
+        continue;
+      for (mlir::OpOperand &reductionOperand : reductionOp->getUses()) {
+        mlir::Operation *user = reductionOperand.getOwner();
+        if (auto store = mlir::dyn_cast<fir::StoreOp>(user)) {
+          if (store.getMemref() == *reductionVal) {
+            // Return immediately: erasing `store` unlinks the use that is
+            // currently being iterated.
+            store.erase();
+            return reductionOp;
+          }
+        }
+        if (auto assign = mlir::dyn_cast<hlfir::AssignOp>(user)) {
+          if (assign.getLhs() == *reductionVal) {
+            assign.erase();
+            return reductionOp;
+          }
+        }
+      }
+    }
+  }
+  return nullptr;
+}
+
+// For a logical reduction X = X op Y, return the fir.convert that produces
+// the converted Y operand of `reductionOp`. This is expected to be the
+// fir.logical<4> -> i1 conversion. Concretely, it is the first operand of
+// `reductionOp` that is defined by a fir.convert whose input is not `loadVal`.
+// Returns a null op if `reductionOp` is null or no such operand exists.
+static fir::ConvertOp getConvertFromReductionOp(mlir::Operation *reductionOp,
+                                                mlir::Value loadVal) {
+  if (!reductionOp)
+    return nullptr;
+  for (mlir::Value reductionOperand : reductionOp->getOperands()) {
+    // getDefiningOp<T>() yields a null op for block arguments, which have no
+    // defining op. mlir::dyn_cast on the null pointer would assert instead.
+    if (auto convertOp = reductionOperand.getDefiningOp<fir::ConvertOp>()) {
+      if (convertOp.getOperand() == loadVal)
+        continue;
+      return convertOp;
+    }
+  }
+  return nullptr;
+}
+
+// Create an omp.reduction, inserted immediately before `op`, that accumulates
+// one of op's inputs into `reductionVal`. The accumulated value is:
+//   - the input of *convertOp, when convertOp is given (used for logical
+//     reductions). If given, *convertOp must refer to a valid op.
+//   - otherwise operand 1 of `op` if operand 0 is `loadVal`, else operand 0.
+// The builder's insertion point is restored on return.
+static void updateReduction(mlir::Operation *op,
+                            fir::FirOpBuilder &firOpBuilder,
+                            mlir::Value loadVal, mlir::Value reductionVal,
+                            fir::ConvertOp *convertOp = nullptr) {
+  // RAII guard restores the caller's insertion point on every exit path.
+  mlir::OpBuilder::InsertionGuard guard(firOpBuilder);
+  firOpBuilder.setInsertionPoint(op);
+
+  mlir::Value reducedValue;
+  if (convertOp)
+    reducedValue = convertOp->getOperand();
+  else if (op->getOperand(0) == loadVal)
+    reducedValue = op->getOperand(1);
+  else
+    reducedValue = op->getOperand(0);
+
+  firOpBuilder.create<mlir::omp::ReductionOp>(op->getLoc(), reducedValue,
+                                              reductionVal);
+}
+
+// Erase every fir.store / hlfir.assign into `symVal` whose stored value is the
+// result of a fir.convert applied to a result of `reductionOp`.
+// Does nothing if `reductionOp` is null.
+static void removeStoreOp(mlir::Operation *reductionOp, mlir::Value symVal) {
+  if (!reductionOp)
+    return;
+  for (mlir::Operation *reductionOpUse : reductionOp->getUsers()) {
+    auto convertReduction = mlir::dyn_cast<fir::ConvertOp>(reductionOpUse);
+    if (!convertReduction)
+      continue;
+    // Erasing a user unlinks its operand from the use list being walked, so
+    // the iterator must be advanced before the loop body runs.
+    for (mlir::Operation *convertReductionUse :
+         llvm::make_early_inc_range(convertReduction.getRes().getUsers())) {
+      // `else if`: once a store is erased, the op must not be touched again.
+      if (auto storeOp = mlir::dyn_cast<fir::StoreOp>(convertReductionUse)) {
+        if (storeOp.getMemref() == symVal)
+          storeOp.erase();
+      } else if (auto assignOp =
+                     mlir::dyn_cast<hlfir::AssignOp>(convertReductionUse)) {
+        if (assignOp.getLhs() == symVal)
+          assignOp.erase();
+      }
+    }
+  }
+}
+
+// Generate an OpenMP reduction operation.
+// TODO: Currently assumes it is either an integer addition/multiplication
+// reduction, or a logical and reduction. Generalize this for various reduction
+// operation types.
+// TODO: Generate the reduction operation during lowering instead of creating
+// and removing operations since this is not a robust approach. Also, removing
+// ops in the builder (instead of a rewriter) is probably not the best approach.
+//
+// For each REDUCTION clause in `clauseList` whose operator is one of the
+// intrinsic operators listed below, or a supported intrinsic procedure, the
+// already-lowered update of each reduction object is located through the
+// object's fir.load users. An omp.reduction is then created before the
+// combining operation, and the store back to the object is erased (see
+// findReductionChain / removeStoreOp).
+static void
+genOpenMPReduction(Fortran::lower::AbstractConverter &converter,
+                   Fortran::semantics::SemanticsContext &semaCtx,
+                   const Fortran::parser::OmpClauseList &clauseList) {
+  fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
+
+  List<Clause> clauses{makeList(clauseList, semaCtx)};
+
+  for (const Clause &clause : clauses) {
+    if (const auto *reductionClause =
+            std::get_if<clause::Reduction>(&clause.u)) {
+      const auto &redOperatorList{
+          std::get<clause::Reduction::ReductionIdentifiers>(
+              reductionClause->t)};
+      assert(redOperatorList.size() == 1 && "Expecting single operator");
+      const auto &redOperator = redOperatorList.front();
+      const auto &objects{std::get<ObjectList>(reductionClause->t)};
+      if (const auto *definedOp =
+              std::get_if<clause::DefinedOperator>(&redOperator.u)) {
+        const auto &intrinsicOp{
+            std::get<clause::DefinedOperator::IntrinsicOperator>(
+                definedOp->u)};
+
+        switch (intrinsicOp) {
+        case clause::DefinedOperator::IntrinsicOperator::Add:
+        case clause::DefinedOperator::IntrinsicOperator::Multiply:
+        case clause::DefinedOperator::IntrinsicOperator::AND:
+        case clause::DefinedOperator::IntrinsicOperator::EQV:
+        case clause::DefinedOperator::IntrinsicOperator::OR:
+        case clause::DefinedOperator::IntrinsicOperator::NEQV:
+          break;
+        default:
+          continue;
+        }
+        for (const Object &object : objects) {
+          if (const Fortran::semantics::Symbol *symbol = object.id()) {
+            mlir::Value reductionVal = converter.getSymbolAddress(*symbol);
+            if (auto declOp = reductionVal.getDefiningOp<hlfir::DeclareOp>())
+              reductionVal = declOp.getBase();
+            mlir::Type reductionType =
+                reductionVal.getType().cast<fir::ReferenceType>().getEleTy();
+            if (!reductionType.isa<fir::LogicalType>()) {
+              if (!reductionType.isIntOrIndexOrFloat())
+                continue;
+            }
+            for (mlir::OpOperand &reductionValUse : reductionVal.getUses()) {
+              if (auto loadOp =
+                      mlir::dyn_cast<fir::LoadOp>(reductionValUse.getOwner())) {
+                mlir::Value loadVal = loadOp.getRes();
+                if (reductionType.isa<fir::LogicalType>()) {
+                  // Skip loads that do not fit the pattern handled here
+                  // (load -> fir.convert -> reduction op, with the other
+                  // operand also produced by a fir.convert). Otherwise a null
+                  // op would be dereferenced below.
+                  mlir::Operation *reductionOp = findReductionChain(loadVal);
+                  if (!reductionOp)
+                    continue;
+                  fir::ConvertOp convertOp =
+                      getConvertFromReductionOp(reductionOp, loadVal);
+                  if (!convertOp)
+                    continue;
+                  updateReduction(reductionOp, firOpBuilder, loadVal,
+                                  reductionVal, &convertOp);
+                  removeStoreOp(reductionOp, reductionVal);
+                } else if (mlir::Operation *reductionOp =
+                               findReductionChain(loadVal, &reductionVal)) {
+                  updateReduction(reductionOp, firOpBuilder, loadVal,
+                                  reductionVal);
+                }
+              }
+            }
+          }
+        }
+      } else if (const auto *reductionIntrinsic =
+                     std::get_if<clause::ProcedureDesignator>(&redOperator.u)) {
+        if (!ReductionProcessor::supportedIntrinsicProcReduction(
+                *reductionIntrinsic))
+          continue;
+        ReductionProcessor::ReductionIdentifier redId =
+            ReductionProcessor::getReductionType(*reductionIntrinsic);
+        for (const Object &object : objects) {
+          if (const Fortran::semantics::Symbol *symbol = object.id()) {
+            mlir::Value reductionVal = converter.getSymbolAddress(*symbol);
+            if (auto declOp = reductionVal.getDefiningOp<hlfir::DeclareOp>())
+              reductionVal = declOp.getBase();
+            for (const mlir::OpOperand &reductionValUse :
+                 reductionVal.getUses()) {
+              if (auto loadOp =
+                      mlir::dyn_cast<fir::LoadOp>(reductionValUse.getOwner())) {
+                mlir::Value loadVal = loadOp.getRes();
+                // MAX/MIN are lowered as a compare -> select.
+                // Match the pattern here.
+                mlir::Operation *reductionOp =
+                    findReductionChain(loadVal, &reductionVal);
+                if (reductionOp == nullptr)
+                  continue;
+
+                if (redId == ReductionProcessor::ReductionIdentifier::MAX ||
+                    redId == ReductionProcessor::ReductionIdentifier::MIN) {
+                  assert(mlir::isa<mlir::arith::SelectOp>(reductionOp) &&
+                         "Selection Op not found in reduction intrinsic");
+                  mlir::Operation *compareOp =
+                      getCompareFromReductionOp(reductionOp, loadVal);
+                  updateReduction(compareOp, firOpBuilder, loadVal,
+                                  reductionVal);
+                }
+                if (redId == ReductionProcessor::ReductionIdentifier::IOR ||
+                    redId == ReductionProcessor::ReductionIdentifier::IEOR ||
+                    redId == ReductionProcessor::ReductionIdentifier::IAND) {
+                  updateReduction(reductionOp, firOpBuilder, loadVal,
+                                  reductionVal);
+                }
+              }
+            }
+          }
+        }
+      }
+    }
+  }
+}
+
struct OpWithBodyGenInfo {
/// A type for a code-gen callback function. This takes as argument the op for
/// which the code is being generated and returns the arguments of the op's
@@ -2339,216 +2546,6 @@ void Fortran::lower::genDeclareTargetIntGlobal(
}
}
-// Generate an OpenMP reduction operation.
-// TODO: Currently assumes it is either an integer addition/multiplication
-// reduction, or a logical and reduction. Generalize this for various reduction
-// operation types.
-// TODO: Generate the reduction operation during lowering instead of creating
-// and removing operations since this is not a robust approach. Also, removing
-// ops in the builder (instead of a rewriter) is probably not the best approach.
-void Fortran::lower::genOpenMPReduction(
- Fortran::lower::AbstractConverter &converter,
- Fortran::semantics::SemanticsContext &semaCtx,
- const Fortran::parser::OmpClauseList &clauseList) {
- fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
-
- List<Clause> clauses{makeList(clauseList, semaCtx)};
-
- for (const Clause &clause : clauses) {
- if (const auto &reductionClause =
- std::get_if<clause::Reduction>(&clause.u)) {
- const auto &redOperatorList{
- std::get<clause::Reduction::ReductionIdentifiers>(
- reductionClause->t)};
- assert(redOperatorList.size() == 1 && "Expecting single operator");
- const auto &redOperator = redOperatorList.front();
- const auto &objects{std::get<ObjectList>(reductionClause->t)};
- if (const auto *reductionOp =
- std::get_if<clause::DefinedOperator>(&redOperator.u)) {
- const auto &intrinsicOp{
- std::get<clause::DefinedOperator::IntrinsicOperator>(
- reductionOp->u)};
-
- switch (intrinsicOp) {
- case clause::DefinedOperator::IntrinsicOperator::Add:
- case clause::DefinedOperator::IntrinsicOperator::Multiply:
- case clause::DefinedOperator::IntrinsicOperator::AND:
- case clause::DefinedOperator::IntrinsicOperator::EQV:
- case clause::DefinedOperator::IntrinsicOperator::OR:
- case clause::DefinedOperator::IntrinsicOperator::NEQV:
- break;
- default:
- continue;
- }
- for (const Object &object : objects) {
- if (const Fortran::semantics::Symbol *symbol = object.id()) {
- mlir::Value reductionVal = converter.getSymbolAddress(*symbol);
- if (auto declOp = reductionVal.getDefiningOp<hlfir::DeclareOp>())
- reductionVal = declOp.getBase();
- mlir::Type reductionType =
- reductionVal.getType().cast<fir::ReferenceType>().getEleTy();
- if (!reductionType.isa<fir::LogicalType>()) {
- if (!reductionType.isIntOrIndexOrFloat())
- continue;
- }
- for (mlir::OpOperand &reductionValUse : reductionVal.getUses()) {
- if (auto loadOp =
- mlir::dyn_cast<fir::LoadOp>(reductionValUse.getOwner())) {
- mlir::Value loadVal = loadOp.getRes();
- if (reductionType.isa<fir::LogicalType>()) {
- mlir::Operation *reductionOp = findReductionChain(loadVal);
- fir::ConvertOp convertOp =
- getConvertFromReductionOp(reductionOp, loadVal);
- updateReduction(reductionOp, firOpBuilder, loadVal,
- reductionVal, &convertOp);
- removeStoreOp(reductionOp, reductionVal);
- } else if (mlir::Operation *reductionOp =
- findReductionChain(loadVal, &reductionVal)) {
- updateReduction(reductionOp, firOpBuilder, loadVal,
- reductionVal);
- }
- }
- }
- }
- }
- } else if (const auto *reductionIntrinsic =
- std::get_if<clause::ProcedureDesignator>(&redOperator.u)) {
- if (!ReductionProcessor::supportedIntrinsicProcReduction(
- *reductionIntrinsic))
- continue;
- ReductionProcessor::ReductionIdentifier redId =
- ReductionProcessor::getReductionType(*reductionIntrinsic);
- for (const Object &object : objects) {
- if (const Fortran::semantics::Symbol *symbol = object.id()) {
- mlir::Value reductionVal = converter.getSymbolAddress(*symbol);
- if (auto declOp = reductionVal.getDefiningOp<hlfir::DeclareOp>())
- reductionVal = declOp.getBase();
- for (const mlir::OpOperand &reductionValUse :
- reductionVal.getUses()) {
- if (auto loadOp =
- mlir::dyn_cast<fir::LoadOp>(reductionValUse.getOwner())) {
- mlir::Value loadVal = loadOp.getRes();
- // Max is lowered as a compare -> select.
- // Match the pattern here.
- mlir::Operation *reductionOp =
- findReductionChain(loadVal, &reductionVal);
- if (reductionOp == nullptr)
- continue;
-
- if (redId == ReductionProcessor::ReductionIdentifier::MAX ||
- redId == ReductionProcessor::ReductionIdentifier::MIN) {
- assert(mlir::isa<mlir::arith::SelectOp>(reductionOp) &&
- "Selection Op not found in reduction intrinsic");
- mlir::Operation *compareOp =
- getCompareFromReductionOp(reductionOp, loadVal);
- updateReduction(compareOp, firOpBuilder, loadVal,
- reductionVal);
- }
- if (redId == ReductionProcessor::ReductionIdentifier::IOR ||
- redId == ReductionProcessor::ReductionIdentifier::IEOR ||
- redId == ReductionProcessor::ReductionIdentifier::IAND) {
- updateReduction(reductionOp, firOpBuilder, loadVal,
- reductionVal);
- }
- }
- }
- }
- }
- }
- }
- }
-}
-
-mlir::Operation *Fortran::lower::findReductionChain(mlir::Value loadVal,
- mlir::Value *reductionVal) {
- for (mlir::OpOperand &loadOperand : loadVal.getUses()) {
- if (mlir::Operation *reductionOp = loadOperand.getOwner()) {
- if (auto convertOp = mlir::dyn_cast<fir::ConvertOp>(reductionOp)) {
- for (mlir::OpOperand &convertOperand : convertOp.getRes().getUses()) {
- if (mlir::Operation *reductionOp = convertOperand.getOwner())
- return reductionOp;
- }
- }
- for (mlir::OpOperand &reductionOperand : reductionOp->getUses()) {
- if (auto store =
- mlir::dyn_cast<fir::StoreOp>(reductionOperand.getOwner())) {
- if (store.getMemref() == *reductionVal) {
- store.erase();
- return reductionOp;
- }
- }
- if (auto assign =
- mlir::dyn_cast<hlfir::AssignOp>(reductionOperand.getOwner())) {
- if (assign.getLhs() == *reductionVal) {
- assign.erase();
- return reductionOp;
- }
- }
- }
- }
- }
- return nullptr;
-}
-
-// for a logical operator 'op' reduction X = X op Y
-// This function returns the operation responsible for converting Y from
-// fir.logical<4> to i1
-fir::ConvertOp
-Fortran::lower::getConvertFromReductionOp(mlir::Operation *reductionOp,
- mlir::Value loadVal) {
- for (mlir::Value reductionOperand : reductionOp->getOperands()) {
- if (auto convertOp =
- mlir::dyn_cast<fir::ConvertOp>(reductionOperand.getDefiningOp())) {
- if (convertOp.getOperand() == loadVal)
- continue;
- return convertOp;
- }
- }
- return nullptr;
-}
-
-void Fortran::lower::updateReduction(mlir::Operation *op,
- fir::FirOpBuilder &firOpBuilder,
- mlir::Value loadVal,
- mlir::Value reductionVal,
- fir::ConvertOp *convertOp) {
- mlir::OpBuilder::InsertPoint insertPtDel = firOpBuilder.saveInsertionPoint();
- firOpBuilder.setInsertionPoint(op);
-
- mlir::Value reductionOp;
- if (convertOp)
- reductionOp = convertOp->getOperand();
- else if (op->getOperand(0) == loadVal)
- reductionOp = op->getOperand(1);
- else
- reductionOp = op->getOperand(0);
-
- firOpBuilder.create<mlir::omp::ReductionOp>(op->getLoc(), reductionOp,
- reductionVal);
- firOpBuilder.restoreInsertionPoint(insertPtDel);
-}
-
-void Fortran::lower::removeStoreOp(mlir::Operation *reductionOp,
- mlir::Value symVal) {
- for (mlir::Operation *reductionOpUse : reductionOp->getUsers()) {
- if (auto convertReduction =
- mlir::dyn_cast<fir::ConvertOp>(reductionOpUse)) {
- for (mlir::Operation *convertReductionUse :
- convertReduction.getRes().getUsers()) {
- if (auto storeOp = mlir::dyn_cast<fir::StoreOp>(convertReductionUse)) {
- if (storeOp.getMemref() == symVal)
- storeOp.erase();
- }
- if (auto assignOp =
- mlir::dyn_cast<hlfir::AssignOp>(convertReductionUse)) {
- if (assignOp.getLhs() == symVal)
- assignOp.erase();
- }
- }
- }
- }
-}
-
bool Fortran::lower::isOpenMPTargetConstruct(
const Fortran::parser::OpenMPConstruct &omp) {
llvm::omp::Directive dir = llvm::omp::Directive::OMPD_unknown;
More information about the flang-commits
mailing list