[flang-commits] [flang] 5784199 - [NFC][Flang][OpenMP] Refactor OpenMP.cpp::genOpenMPReduction

Dylan Fleming via flang-commits flang-commits at lists.llvm.org
Mon Aug 8 07:27:47 PDT 2022


Author: Dylan Fleming
Date: 2022-08-08T14:27:32Z
New Revision: 5784199dd1c12be098d3ac8ca7b91dc719539bfd

URL: https://github.com/llvm/llvm-project/commit/5784199dd1c12be098d3ac8ca7b91dc719539bfd
DIFF: https://github.com/llvm/llvm-project/commit/5784199dd1c12be098d3ac8ca7b91dc719539bfd.diff

LOG: [NFC][Flang][OpenMP] Refactor OpenMP.cpp::genOpenMPReduction

This patch serves two main purposes:
Firstly, to split some of the logic into separate functions
to try and improve readability.

On top of this, it aims to make creating the reductions more generic.
That way, subsequent patches adding reductions shouldn't need
to add a significant amount of extra logic checks, such as checking
for specific operators.

Reviewed By: awarzynski

Differential Revision: https://reviews.llvm.org/D131161

Added: 
    

Modified: 
    flang/include/flang/Lower/OpenMP.h
    flang/lib/Lower/OpenMP.cpp

Removed: 
    


################################################################################
diff  --git a/flang/include/flang/Lower/OpenMP.h b/flang/include/flang/Lower/OpenMP.h
index aa1b43b89eace..f2a46061425ac 100644
--- a/flang/include/flang/Lower/OpenMP.h
+++ b/flang/include/flang/Lower/OpenMP.h
@@ -15,6 +15,15 @@
 
 #include <cinttypes>
 
+namespace mlir {
+class Value;
+class Operation;
+} // namespace mlir
+
+namespace fir {
+class FirOpBuilder;
+} // namespace fir
+
 namespace Fortran {
 namespace parser {
 struct OpenMPConstruct;
@@ -41,6 +50,11 @@ void genThreadprivateOp(AbstractConverter &, const pft::Variable &);
 void genOpenMPReduction(AbstractConverter &,
                         const Fortran::parser::OmpClauseList &clauseList);
 
+void updateReduction(mlir::Operation *, fir::FirOpBuilder &, mlir::Value,
+                     mlir::Value);
+
+mlir::Operation *getReductionInChain(mlir::Value, mlir::Value);
+
 } // namespace lower
 } // namespace Fortran
 

diff  --git a/flang/lib/Lower/OpenMP.cpp b/flang/lib/Lower/OpenMP.cpp
index c7dc6f5ed02a4..ea8b239c75a7d 100644
--- a/flang/lib/Lower/OpenMP.cpp
+++ b/flang/lib/Lower/OpenMP.cpp
@@ -1633,42 +1633,19 @@ void Fortran::lower::genOpenMPReduction(
           if (const auto *name{
                   Fortran::parser::Unwrap<Fortran::parser::Name>(ompObject)}) {
             if (const auto *symbol{name->symbol}) {
-              mlir::Value symVal = converter.getSymbolAddress(*symbol);
-              mlir::Type redType =
-                  symVal.getType().cast<fir::ReferenceType>().getEleTy();
-              if (!redType.isIntOrIndex())
+              mlir::Value reductionVal = converter.getSymbolAddress(*symbol);
+              mlir::Type reductionType =
+                  reductionVal.getType().cast<fir::ReferenceType>().getEleTy();
+              if (!reductionType.isIntOrIndex())
                 continue;
-              for (mlir::OpOperand &use1 : symVal.getUses()) {
-                if (auto load = mlir::dyn_cast<fir::LoadOp>(use1.getOwner())) {
-                  mlir::Value loadVal = load.getRes();
-                  for (mlir::OpOperand &use2 : loadVal.getUses()) {
-                    if (auto add = mlir::dyn_cast<mlir::arith::AddIOp>(
-                            use2.getOwner())) {
-                      mlir::Value addRes = add.getResult();
-                      for (mlir::OpOperand &use3 : addRes.getUses()) {
-                        if (auto store =
-                                mlir::dyn_cast<fir::StoreOp>(use3.getOwner())) {
-                          if (store.getMemref() == symVal) {
-                            // Chain found! Now replace load->reduction->store
-                            // with the OpenMP reduction operation.
-                            mlir::OpBuilder::InsertPoint insertPtDel =
-                                firOpBuilder.saveInsertionPoint();
-                            firOpBuilder.setInsertionPoint(add);
-                            if (add.getLhs() == loadVal) {
-                              firOpBuilder.create<mlir::omp::ReductionOp>(
-                                  add.getLoc(), add.getRhs(), symVal);
-                            } else {
-                              firOpBuilder.create<mlir::omp::ReductionOp>(
-                                  add.getLoc(), add.getLhs(), symVal);
-                            }
-                            store.erase();
-                            add.erase();
-                            load.erase();
-                            firOpBuilder.restoreInsertionPoint(insertPtDel);
-                          }
-                        }
-                      }
-                    }
+
+              for (mlir::OpOperand &reductionValUse : reductionVal.getUses()) {
+
+                if (auto loadOp =
+                        mlir::dyn_cast<fir::LoadOp>(reductionValUse.getOwner())) {
+                  mlir::Value loadVal = loadOp.getRes();
+                  if (auto reductionOp = getReductionInChain(reductionVal, loadVal)) {
+                    updateReduction(reductionOp, firOpBuilder, loadVal, reductionVal);
                   }
                 }
               }
@@ -1679,3 +1656,70 @@ void Fortran::lower::genOpenMPReduction(
     }
   }
 }
+
+// Looks for the operation that combines loadVal (the result of a fir.load
+// from reductionVal) with one other value and whose single result is then
+// stored straight back into reductionVal, i.e. `op` in the chain
+//   %l = fir.load %x ; %r = op %l, %v ; fir.store %r to %x
+// Returns that operation, or nullptr if there is no such chain.
+//
+// reductionVal is the address of the reduction variable (genOpenMPReduction
+// obtains it from getSymbolAddress), not the result of an OpenMP reduction
+// operation.
+//
+// Side effect: when a chain is found, the matching fir.store is erased, so a
+// successful call is expected to be followed by updateReduction().
+//
+// Only binary, single-result operations are accepted, so that
+// updateReduction() can always select "the other" operand.
+// NOTE(review): the kind of operation is not checked against the clause's
+// reduction operator (e.g. `x = x * y` under reduction(+:x) still matches);
+// callers must ensure the two agree.
+mlir::Operation *Fortran::lower::getReductionInChain(mlir::Value reductionVal,
+                                                     mlir::Value loadVal) {
+  for (mlir::OpOperand &loadUse : loadVal.getUses()) {
+    mlir::Operation *reductionOp = loadUse.getOwner();
+    // Skip anything that is not "value op value -> value", e.g. a unary op or
+    // the fir.store of a plain `x = x`.
+    if (reductionOp->getNumOperands() != 2 || reductionOp->getNumResults() != 1)
+      continue;
+    for (mlir::OpOperand &reductionOperand : reductionOp->getUses()) {
+      auto store = mlir::dyn_cast<fir::StoreOp>(reductionOperand.getOwner());
+      if (store && store.getMemref() == reductionVal) {
+        store.erase();
+        return reductionOp;
+      }
+    }
+  }
+  return nullptr;
+}
+
+// Replaces the chain found by getReductionInChain with an omp.reduction
+// inserted immediately before op; the operand of op that is not loadVal is
+// the value accumulated into reductionVal. op must have exactly two operands.
+//
+// If nothing uses op's result any more (its fir.store having been erased by
+// getReductionInChain), op is erased too, as the code before this
+// refactoring did. The fir.load that produced loadVal is intentionally left
+// in place: the caller is iterating over the uses of reductionVal, one of
+// which is that load's operand, so erasing it here would invalidate the
+// caller's iterator.
+void Fortran::lower::updateReduction(mlir::Operation *op,
+                                     fir::FirOpBuilder &firOpBuilder,
+                                     mlir::Value loadVal,
+                                     mlir::Value reductionVal) {
+  mlir::OpBuilder::InsertPoint insertPtDel = firOpBuilder.saveInsertionPoint();
+  firOpBuilder.setInsertionPoint(op);
+
+  // The value being accumulated is whichever operand is not the loaded
+  // reduction variable.
+  mlir::Value reducedVal =
+      op->getOperand(0) == loadVal ? op->getOperand(1) : op->getOperand(0);
+  firOpBuilder.create<mlir::omp::ReductionOp>(op->getLoc(), reducedVal,
+                                              reductionVal);
+
+  if (op->use_empty())
+    op->erase();
+
+  firOpBuilder.restoreInsertionPoint(insertPtDel);
+}


        


More information about the flang-commits mailing list