[flang-commits] [flang] 5784199 - [NFC][Flang][OpenMP] Refactor OpenMP.cpp::genOpenMPReduction
Dylan Fleming via flang-commits
flang-commits at lists.llvm.org
Mon Aug 8 07:27:47 PDT 2022
Author: Dylan Fleming
Date: 2022-08-08T14:27:32Z
New Revision: 5784199dd1c12be098d3ac8ca7b91dc719539bfd
URL: https://github.com/llvm/llvm-project/commit/5784199dd1c12be098d3ac8ca7b91dc719539bfd
DIFF: https://github.com/llvm/llvm-project/commit/5784199dd1c12be098d3ac8ca7b91dc719539bfd.diff
LOG: [NFC][Flang][OpenMP] Refactor OpenMP.cpp::genOpenMPReduction
This patch serves two main purposes:
Firstly, to split some of the logic into separate functions
to improve readability.
Secondly, it aims to make the creation of reductions more generic,
so that subsequent patches adding new reductions should not need
a significant amount of extra logic, such as checks
for specific operators.
Reviewed By: awarzynski
Differential Revision: https://reviews.llvm.org/D131161
Added:
Modified:
flang/include/flang/Lower/OpenMP.h
flang/lib/Lower/OpenMP.cpp
Removed:
################################################################################
diff --git a/flang/include/flang/Lower/OpenMP.h b/flang/include/flang/Lower/OpenMP.h
index aa1b43b89eace..f2a46061425ac 100644
--- a/flang/include/flang/Lower/OpenMP.h
+++ b/flang/include/flang/Lower/OpenMP.h
@@ -15,6 +15,15 @@
#include <cinttypes>
+namespace mlir {
+class Value;
+class Operation;
+} // namespace mlir
+
+namespace fir {
+class FirOpBuilder;
+} // namespace fir
+
namespace Fortran {
namespace parser {
struct OpenMPConstruct;
@@ -41,6 +50,11 @@ void genThreadprivateOp(AbstractConverter &, const pft::Variable &);
void genOpenMPReduction(AbstractConverter &,
const Fortran::parser::OmpClauseList &clauseList);
+void updateReduction(mlir::Operation *, fir::FirOpBuilder &, mlir::Value,
+ mlir::Value); // args: op, firOpBuilder, loadVal, reductionVal
+// Args: reductionVal, loadVal. Erases the matching fir.store as a side effect.
+mlir::Operation *getReductionInChain(mlir::Value, mlir::Value);
+
} // namespace lower
} // namespace Fortran
diff --git a/flang/lib/Lower/OpenMP.cpp b/flang/lib/Lower/OpenMP.cpp
index c7dc6f5ed02a4..ea8b239c75a7d 100644
--- a/flang/lib/Lower/OpenMP.cpp
+++ b/flang/lib/Lower/OpenMP.cpp
@@ -1633,42 +1633,19 @@ void Fortran::lower::genOpenMPReduction(
if (const auto *name{
Fortran::parser::Unwrap<Fortran::parser::Name>(ompObject)}) {
if (const auto *symbol{name->symbol}) {
- mlir::Value symVal = converter.getSymbolAddress(*symbol);
- mlir::Type redType =
- symVal.getType().cast<fir::ReferenceType>().getEleTy();
- if (!redType.isIntOrIndex())
+ mlir::Value reductionVal = converter.getSymbolAddress(*symbol);
+ mlir::Type reductionType =
+ reductionVal.getType().cast<fir::ReferenceType>().getEleTy();
+ if (!reductionType.isIntOrIndex())
continue;
- for (mlir::OpOperand &use1 : symVal.getUses()) {
- if (auto load = mlir::dyn_cast<fir::LoadOp>(use1.getOwner())) {
- mlir::Value loadVal = load.getRes();
- for (mlir::OpOperand &use2 : loadVal.getUses()) {
- if (auto add = mlir::dyn_cast<mlir::arith::AddIOp>(
- use2.getOwner())) {
- mlir::Value addRes = add.getResult();
- for (mlir::OpOperand &use3 : addRes.getUses()) {
- if (auto store =
- mlir::dyn_cast<fir::StoreOp>(use3.getOwner())) {
- if (store.getMemref() == symVal) {
- // Chain found! Now replace load->reduction->store
- // with the OpenMP reduction operation.
- mlir::OpBuilder::InsertPoint insertPtDel =
- firOpBuilder.saveInsertionPoint();
- firOpBuilder.setInsertionPoint(add);
- if (add.getLhs() == loadVal) {
- firOpBuilder.create<mlir::omp::ReductionOp>(
- add.getLoc(), add.getRhs(), symVal);
- } else {
- firOpBuilder.create<mlir::omp::ReductionOp>(
- add.getLoc(), add.getLhs(), symVal);
- }
- store.erase();
- add.erase();
- load.erase();
- firOpBuilder.restoreInsertionPoint(insertPtDel);
- }
- }
- }
- }
+
+ for (mlir::OpOperand &reductionValUse : reductionVal.getUses()) {
+
+ if (auto loadOp =
+ mlir::dyn_cast<fir::LoadOp>(reductionValUse.getOwner())) {
+ mlir::Value loadVal = loadOp.getRes();
+ if (auto reductionOp = getReductionInChain(reductionVal, loadVal)) {
+ updateReduction(reductionOp, firOpBuilder, loadVal, reductionVal);
}
}
}
@@ -1679,3 +1656,42 @@ void Fortran::lower::genOpenMPReduction(
}
}
}
+
+// Searches the users of loadVal for an op whose result is fir.store'd into
+// reductionVal, and returns that op (nullptr if none). Any op kind matches.
+// NOTE(review): not a pure query - the matching fir.store is erased here.
+// loadVal is expected to be the result of a fir.load of reductionVal.
+// reductionVal is the reduction variable's address, not an omp.reduction result.
+mlir::Operation *Fortran::lower::getReductionInChain(mlir::Value reductionVal,
+ mlir::Value loadVal) {
+ for (mlir::OpOperand &loadUse : loadVal.getUses()) {
+ if (auto reductionOp = loadUse.getOwner()) {
+ for (mlir::OpOperand &reductionOperand : reductionOp->getUses()) {
+ if (auto store =
+ mlir::dyn_cast<fir::StoreOp>(reductionOperand.getOwner())) {
+ if (store.getMemref() == reductionVal) {
+ store.erase();
+ return reductionOp;
+ }
+ }
+ }
+ }
+ }
+ return nullptr;
+}
+// Inserts an omp.reduction before op that folds op's non-loadVal operand into reductionVal.
+void Fortran::lower::updateReduction(mlir::Operation *op,
+ fir::FirOpBuilder &firOpBuilder,
+ mlir::Value loadVal, mlir::Value reductionVal) {
+ mlir::OpBuilder::InsertPoint insertPtDel = firOpBuilder.saveInsertionPoint();
+ firOpBuilder.setInsertionPoint(op);
+ // Assumes op is binary: one of operands 0/1 is loadVal, the other is reduced.
+ if (op->getOperand(0) == loadVal)
+ firOpBuilder.create<mlir::omp::ReductionOp>(op->getLoc(), op->getOperand(1),
+ reductionVal);
+ else
+ firOpBuilder.create<mlir::omp::ReductionOp>(op->getLoc(), op->getOperand(0),
+ reductionVal);
+ // NOTE(review): op and the fir.load stay in place (pre-refactor code erased them).
+ firOpBuilder.restoreInsertionPoint(insertPtDel);
+}
More information about the flang-commits
mailing list