[flang-commits] [flang] [flang][OpenMP] Make several functions local to OpenMP.cpp, NFC (PR #86726)

Krzysztof Parzyszek via flang-commits flang-commits at lists.llvm.org
Tue Mar 26 13:04:58 PDT 2024


https://github.com/kparzysz created https://github.com/llvm/llvm-project/pull/86726

There were several functions, mostly reduction-related, that were only called from OpenMP.cpp. Remove them from OpenMP.h, and make them local in OpenMP.cpp:
- genOpenMPReduction
- findReductionChain
- getConvertFromReductionOp
- updateReduction
- removeStoreOp

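Also, move the function bodies out of the "public" section of OpenMP.cpp, i.e. away from the definitions of the Fortran::lower:: functions that are declared in OpenMP.h.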

>From edd931dabfdf283db16f2fb2510c51f56127d9bb Mon Sep 17 00:00:00 2001
From: Krzysztof Parzyszek <Krzysztof.Parzyszek at amd.com>
Date: Tue, 26 Mar 2024 13:39:49 -0500
Subject: [PATCH] [flang][OpenMP] Make several functions local to OpenMP.cpp,
 NFC

There were several functions, mostly reduction-related, that were only called
from OpenMP.cpp. Remove them from OpenMP.h, and make them local in OpenMP.cpp:
- genOpenMPReduction
- findReductionChain
- getConvertFromReductionOp
- updateReduction
- removeStoreOp

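Also, move the function bodies out of the "public" section of OpenMP.cpp, i.e.
away from the definitions of the Fortran::lower:: functions that are declared
in OpenMP.h.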
---
 flang/include/flang/Lower/OpenMP.h |  12 -
 flang/lib/Lower/OpenMP/OpenMP.cpp  | 417 ++++++++++++++---------------
 2 files changed, 207 insertions(+), 222 deletions(-)

diff --git a/flang/include/flang/Lower/OpenMP.h b/flang/include/flang/Lower/OpenMP.h
index 3b22a652d1fc1e..6e150ef4e8e82f 100644
--- a/flang/include/flang/Lower/OpenMP.h
+++ b/flang/include/flang/Lower/OpenMP.h
@@ -19,7 +19,6 @@
 #include <utility>
 
 namespace mlir {
-class Value;
 class Operation;
 class Location;
 namespace omp {
@@ -30,7 +29,6 @@ enum class DeclareTargetCaptureClause : uint32_t;
 
 namespace fir {
 class FirOpBuilder;
-class ConvertOp;
 } // namespace fir
 
 namespace Fortran {
@@ -84,16 +82,6 @@ void genOpenMPSymbolProperties(AbstractConverter &converter,
 int64_t getCollapseValue(const Fortran::parser::OmpClauseList &clauseList);
 void genThreadprivateOp(AbstractConverter &, const pft::Variable &);
 void genDeclareTargetIntGlobal(AbstractConverter &, const pft::Variable &);
-void genOpenMPReduction(AbstractConverter &,
-                        Fortran::semantics::SemanticsContext &,
-                        const Fortran::parser::OmpClauseList &clauseList);
-
-mlir::Operation *findReductionChain(mlir::Value, mlir::Value * = nullptr);
-fir::ConvertOp getConvertFromReductionOp(mlir::Operation *, mlir::Value);
-void updateReduction(mlir::Operation *, fir::FirOpBuilder &, mlir::Value,
-                     mlir::Value, fir::ConvertOp * = nullptr);
-void removeStoreOp(mlir::Operation *, mlir::Value);
-
 bool isOpenMPTargetConstruct(const parser::OpenMPConstruct &);
 bool isOpenMPDeviceDeclareTarget(Fortran::lower::AbstractConverter &,
                                  Fortran::semantics::SemanticsContext &,
diff --git a/flang/lib/Lower/OpenMP/OpenMP.cpp b/flang/lib/Lower/OpenMP/OpenMP.cpp
index 0cf2a8f97040a8..0a728b65afbf06 100644
--- a/flang/lib/Lower/OpenMP/OpenMP.cpp
+++ b/flang/lib/Lower/OpenMP/OpenMP.cpp
@@ -237,6 +237,213 @@ createAndSetPrivatizedLoopVar(Fortran::lower::AbstractConverter &converter,
   return storeOp;
 }
 
+static mlir::Operation *
+findReductionChain(mlir::Value loadVal, mlir::Value *reductionVal = nullptr) {
+  for (mlir::OpOperand &loadOperand : loadVal.getUses()) {
+    if (mlir::Operation *reductionOp = loadOperand.getOwner()) {
+      if (auto convertOp = mlir::dyn_cast<fir::ConvertOp>(reductionOp)) {
+        for (mlir::OpOperand &convertOperand : convertOp.getRes().getUses()) {
+          if (mlir::Operation *reductionOp = convertOperand.getOwner())
+            return reductionOp;
+        }
+      }
+      for (mlir::OpOperand &reductionOperand : reductionOp->getUses()) {
+        if (auto store =
+                mlir::dyn_cast<fir::StoreOp>(reductionOperand.getOwner())) {
+          if (store.getMemref() == *reductionVal) {
+            store.erase();
+            return reductionOp;
+          }
+        }
+        if (auto assign =
+                mlir::dyn_cast<hlfir::AssignOp>(reductionOperand.getOwner())) {
+          if (assign.getLhs() == *reductionVal) {
+            assign.erase();
+            return reductionOp;
+          }
+        }
+      }
+    }
+  }
+  return nullptr;
+}
+
+// for a logical operator 'op' reduction X = X op Y
+// This function returns the operation responsible for converting Y from
+// fir.logical<4> to i1
+static fir::ConvertOp getConvertFromReductionOp(mlir::Operation *reductionOp,
+                                                mlir::Value loadVal) {
+  for (mlir::Value reductionOperand : reductionOp->getOperands()) {
+    if (auto convertOp =
+            mlir::dyn_cast<fir::ConvertOp>(reductionOperand.getDefiningOp())) {
+      if (convertOp.getOperand() == loadVal)
+        continue;
+      return convertOp;
+    }
+  }
+  return nullptr;
+}
+
+static void updateReduction(mlir::Operation *op,
+                            fir::FirOpBuilder &firOpBuilder,
+                            mlir::Value loadVal, mlir::Value reductionVal,
+                            fir::ConvertOp *convertOp = nullptr) {
+  mlir::OpBuilder::InsertPoint insertPtDel = firOpBuilder.saveInsertionPoint();
+  firOpBuilder.setInsertionPoint(op);
+
+  mlir::Value reductionOp;
+  if (convertOp)
+    reductionOp = convertOp->getOperand();
+  else if (op->getOperand(0) == loadVal)
+    reductionOp = op->getOperand(1);
+  else
+    reductionOp = op->getOperand(0);
+
+  firOpBuilder.create<mlir::omp::ReductionOp>(op->getLoc(), reductionOp,
+                                              reductionVal);
+  firOpBuilder.restoreInsertionPoint(insertPtDel);
+}
+
+static void removeStoreOp(mlir::Operation *reductionOp, mlir::Value symVal) {
+  for (mlir::Operation *reductionOpUse : reductionOp->getUsers()) {
+    if (auto convertReduction =
+            mlir::dyn_cast<fir::ConvertOp>(reductionOpUse)) {
+      for (mlir::Operation *convertReductionUse :
+           convertReduction.getRes().getUsers()) {
+        if (auto storeOp = mlir::dyn_cast<fir::StoreOp>(convertReductionUse)) {
+          if (storeOp.getMemref() == symVal)
+            storeOp.erase();
+        }
+        if (auto assignOp =
+                mlir::dyn_cast<hlfir::AssignOp>(convertReductionUse)) {
+          if (assignOp.getLhs() == symVal)
+            assignOp.erase();
+        }
+      }
+    }
+  }
+}
+
+// Generate an OpenMP reduction operation.
+// TODO: Currently assumes it is either an integer addition/multiplication
+// reduction, or a logical and reduction. Generalize this for various reduction
+// operation types.
+// TODO: Generate the reduction operation during lowering instead of creating
+// and removing operations since this is not a robust approach. Also, removing
+// ops in the builder (instead of a rewriter) is probably not the best approach.
+static void genOpenMPReduction(
+    Fortran::lower::AbstractConverter &converter,
+    Fortran::semantics::SemanticsContext &semaCtx,
+    const Fortran::parser::OmpClauseList &clauseList) {
+  fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
+
+  List<Clause> clauses{makeList(clauseList, semaCtx)};
+
+  for (const Clause &clause : clauses) {
+    if (const auto &reductionClause =
+            std::get_if<clause::Reduction>(&clause.u)) {
+      const auto &redOperatorList{
+          std::get<clause::Reduction::ReductionIdentifiers>(
+              reductionClause->t)};
+      assert(redOperatorList.size() == 1 && "Expecting single operator");
+      const auto &redOperator = redOperatorList.front();
+      const auto &objects{std::get<ObjectList>(reductionClause->t)};
+      if (const auto *reductionOp =
+              std::get_if<clause::DefinedOperator>(&redOperator.u)) {
+        const auto &intrinsicOp{
+            std::get<clause::DefinedOperator::IntrinsicOperator>(
+                reductionOp->u)};
+
+        switch (intrinsicOp) {
+        case clause::DefinedOperator::IntrinsicOperator::Add:
+        case clause::DefinedOperator::IntrinsicOperator::Multiply:
+        case clause::DefinedOperator::IntrinsicOperator::AND:
+        case clause::DefinedOperator::IntrinsicOperator::EQV:
+        case clause::DefinedOperator::IntrinsicOperator::OR:
+        case clause::DefinedOperator::IntrinsicOperator::NEQV:
+          break;
+        default:
+          continue;
+        }
+        for (const Object &object : objects) {
+          if (const Fortran::semantics::Symbol *symbol = object.id()) {
+            mlir::Value reductionVal = converter.getSymbolAddress(*symbol);
+            if (auto declOp = reductionVal.getDefiningOp<hlfir::DeclareOp>())
+              reductionVal = declOp.getBase();
+            mlir::Type reductionType =
+                reductionVal.getType().cast<fir::ReferenceType>().getEleTy();
+            if (!reductionType.isa<fir::LogicalType>()) {
+              if (!reductionType.isIntOrIndexOrFloat())
+                continue;
+            }
+            for (mlir::OpOperand &reductionValUse : reductionVal.getUses()) {
+              if (auto loadOp =
+                      mlir::dyn_cast<fir::LoadOp>(reductionValUse.getOwner())) {
+                mlir::Value loadVal = loadOp.getRes();
+                if (reductionType.isa<fir::LogicalType>()) {
+                  mlir::Operation *reductionOp = findReductionChain(loadVal);
+                  fir::ConvertOp convertOp =
+                      getConvertFromReductionOp(reductionOp, loadVal);
+                  updateReduction(reductionOp, firOpBuilder, loadVal,
+                                  reductionVal, &convertOp);
+                  removeStoreOp(reductionOp, reductionVal);
+                } else if (mlir::Operation *reductionOp =
+                               findReductionChain(loadVal, &reductionVal)) {
+                  updateReduction(reductionOp, firOpBuilder, loadVal,
+                                  reductionVal);
+                }
+              }
+            }
+          }
+        }
+      } else if (const auto *reductionIntrinsic =
+                     std::get_if<clause::ProcedureDesignator>(&redOperator.u)) {
+        if (!ReductionProcessor::supportedIntrinsicProcReduction(
+                *reductionIntrinsic))
+          continue;
+        ReductionProcessor::ReductionIdentifier redId =
+            ReductionProcessor::getReductionType(*reductionIntrinsic);
+        for (const Object &object : objects) {
+          if (const Fortran::semantics::Symbol *symbol = object.id()) {
+            mlir::Value reductionVal = converter.getSymbolAddress(*symbol);
+            if (auto declOp = reductionVal.getDefiningOp<hlfir::DeclareOp>())
+              reductionVal = declOp.getBase();
+            for (const mlir::OpOperand &reductionValUse :
+                 reductionVal.getUses()) {
+              if (auto loadOp =
+                      mlir::dyn_cast<fir::LoadOp>(reductionValUse.getOwner())) {
+                mlir::Value loadVal = loadOp.getRes();
+                // Max is lowered as a compare -> select.
+                // Match the pattern here.
+                mlir::Operation *reductionOp =
+                    findReductionChain(loadVal, &reductionVal);
+                if (reductionOp == nullptr)
+                  continue;
+
+                if (redId == ReductionProcessor::ReductionIdentifier::MAX ||
+                    redId == ReductionProcessor::ReductionIdentifier::MIN) {
+                  assert(mlir::isa<mlir::arith::SelectOp>(reductionOp) &&
+                         "Selection Op not found in reduction intrinsic");
+                  mlir::Operation *compareOp =
+                      getCompareFromReductionOp(reductionOp, loadVal);
+                  updateReduction(compareOp, firOpBuilder, loadVal,
+                                  reductionVal);
+                }
+                if (redId == ReductionProcessor::ReductionIdentifier::IOR ||
+                    redId == ReductionProcessor::ReductionIdentifier::IEOR ||
+                    redId == ReductionProcessor::ReductionIdentifier::IAND) {
+                  updateReduction(reductionOp, firOpBuilder, loadVal,
+                                  reductionVal);
+                }
+              }
+            }
+          }
+        }
+      }
+    }
+  }
+}
+
 struct OpWithBodyGenInfo {
   /// A type for a code-gen callback function. This takes as argument the op for
   /// which the code is being generated and returns the arguments of the op's
@@ -2339,216 +2546,6 @@ void Fortran::lower::genDeclareTargetIntGlobal(
   }
 }
 
-// Generate an OpenMP reduction operation.
-// TODO: Currently assumes it is either an integer addition/multiplication
-// reduction, or a logical and reduction. Generalize this for various reduction
-// operation types.
-// TODO: Generate the reduction operation during lowering instead of creating
-// and removing operations since this is not a robust approach. Also, removing
-// ops in the builder (instead of a rewriter) is probably not the best approach.
-void Fortran::lower::genOpenMPReduction(
-    Fortran::lower::AbstractConverter &converter,
-    Fortran::semantics::SemanticsContext &semaCtx,
-    const Fortran::parser::OmpClauseList &clauseList) {
-  fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
-
-  List<Clause> clauses{makeList(clauseList, semaCtx)};
-
-  for (const Clause &clause : clauses) {
-    if (const auto &reductionClause =
-            std::get_if<clause::Reduction>(&clause.u)) {
-      const auto &redOperatorList{
-          std::get<clause::Reduction::ReductionIdentifiers>(
-              reductionClause->t)};
-      assert(redOperatorList.size() == 1 && "Expecting single operator");
-      const auto &redOperator = redOperatorList.front();
-      const auto &objects{std::get<ObjectList>(reductionClause->t)};
-      if (const auto *reductionOp =
-              std::get_if<clause::DefinedOperator>(&redOperator.u)) {
-        const auto &intrinsicOp{
-            std::get<clause::DefinedOperator::IntrinsicOperator>(
-                reductionOp->u)};
-
-        switch (intrinsicOp) {
-        case clause::DefinedOperator::IntrinsicOperator::Add:
-        case clause::DefinedOperator::IntrinsicOperator::Multiply:
-        case clause::DefinedOperator::IntrinsicOperator::AND:
-        case clause::DefinedOperator::IntrinsicOperator::EQV:
-        case clause::DefinedOperator::IntrinsicOperator::OR:
-        case clause::DefinedOperator::IntrinsicOperator::NEQV:
-          break;
-        default:
-          continue;
-        }
-        for (const Object &object : objects) {
-          if (const Fortran::semantics::Symbol *symbol = object.id()) {
-            mlir::Value reductionVal = converter.getSymbolAddress(*symbol);
-            if (auto declOp = reductionVal.getDefiningOp<hlfir::DeclareOp>())
-              reductionVal = declOp.getBase();
-            mlir::Type reductionType =
-                reductionVal.getType().cast<fir::ReferenceType>().getEleTy();
-            if (!reductionType.isa<fir::LogicalType>()) {
-              if (!reductionType.isIntOrIndexOrFloat())
-                continue;
-            }
-            for (mlir::OpOperand &reductionValUse : reductionVal.getUses()) {
-              if (auto loadOp =
-                      mlir::dyn_cast<fir::LoadOp>(reductionValUse.getOwner())) {
-                mlir::Value loadVal = loadOp.getRes();
-                if (reductionType.isa<fir::LogicalType>()) {
-                  mlir::Operation *reductionOp = findReductionChain(loadVal);
-                  fir::ConvertOp convertOp =
-                      getConvertFromReductionOp(reductionOp, loadVal);
-                  updateReduction(reductionOp, firOpBuilder, loadVal,
-                                  reductionVal, &convertOp);
-                  removeStoreOp(reductionOp, reductionVal);
-                } else if (mlir::Operation *reductionOp =
-                               findReductionChain(loadVal, &reductionVal)) {
-                  updateReduction(reductionOp, firOpBuilder, loadVal,
-                                  reductionVal);
-                }
-              }
-            }
-          }
-        }
-      } else if (const auto *reductionIntrinsic =
-                     std::get_if<clause::ProcedureDesignator>(&redOperator.u)) {
-        if (!ReductionProcessor::supportedIntrinsicProcReduction(
-                *reductionIntrinsic))
-          continue;
-        ReductionProcessor::ReductionIdentifier redId =
-            ReductionProcessor::getReductionType(*reductionIntrinsic);
-        for (const Object &object : objects) {
-          if (const Fortran::semantics::Symbol *symbol = object.id()) {
-            mlir::Value reductionVal = converter.getSymbolAddress(*symbol);
-            if (auto declOp = reductionVal.getDefiningOp<hlfir::DeclareOp>())
-              reductionVal = declOp.getBase();
-            for (const mlir::OpOperand &reductionValUse :
-                 reductionVal.getUses()) {
-              if (auto loadOp =
-                      mlir::dyn_cast<fir::LoadOp>(reductionValUse.getOwner())) {
-                mlir::Value loadVal = loadOp.getRes();
-                // Max is lowered as a compare -> select.
-                // Match the pattern here.
-                mlir::Operation *reductionOp =
-                    findReductionChain(loadVal, &reductionVal);
-                if (reductionOp == nullptr)
-                  continue;
-
-                if (redId == ReductionProcessor::ReductionIdentifier::MAX ||
-                    redId == ReductionProcessor::ReductionIdentifier::MIN) {
-                  assert(mlir::isa<mlir::arith::SelectOp>(reductionOp) &&
-                         "Selection Op not found in reduction intrinsic");
-                  mlir::Operation *compareOp =
-                      getCompareFromReductionOp(reductionOp, loadVal);
-                  updateReduction(compareOp, firOpBuilder, loadVal,
-                                  reductionVal);
-                }
-                if (redId == ReductionProcessor::ReductionIdentifier::IOR ||
-                    redId == ReductionProcessor::ReductionIdentifier::IEOR ||
-                    redId == ReductionProcessor::ReductionIdentifier::IAND) {
-                  updateReduction(reductionOp, firOpBuilder, loadVal,
-                                  reductionVal);
-                }
-              }
-            }
-          }
-        }
-      }
-    }
-  }
-}
-
-mlir::Operation *Fortran::lower::findReductionChain(mlir::Value loadVal,
-                                                    mlir::Value *reductionVal) {
-  for (mlir::OpOperand &loadOperand : loadVal.getUses()) {
-    if (mlir::Operation *reductionOp = loadOperand.getOwner()) {
-      if (auto convertOp = mlir::dyn_cast<fir::ConvertOp>(reductionOp)) {
-        for (mlir::OpOperand &convertOperand : convertOp.getRes().getUses()) {
-          if (mlir::Operation *reductionOp = convertOperand.getOwner())
-            return reductionOp;
-        }
-      }
-      for (mlir::OpOperand &reductionOperand : reductionOp->getUses()) {
-        if (auto store =
-                mlir::dyn_cast<fir::StoreOp>(reductionOperand.getOwner())) {
-          if (store.getMemref() == *reductionVal) {
-            store.erase();
-            return reductionOp;
-          }
-        }
-        if (auto assign =
-                mlir::dyn_cast<hlfir::AssignOp>(reductionOperand.getOwner())) {
-          if (assign.getLhs() == *reductionVal) {
-            assign.erase();
-            return reductionOp;
-          }
-        }
-      }
-    }
-  }
-  return nullptr;
-}
-
-// for a logical operator 'op' reduction X = X op Y
-// This function returns the operation responsible for converting Y from
-// fir.logical<4> to i1
-fir::ConvertOp
-Fortran::lower::getConvertFromReductionOp(mlir::Operation *reductionOp,
-                                          mlir::Value loadVal) {
-  for (mlir::Value reductionOperand : reductionOp->getOperands()) {
-    if (auto convertOp =
-            mlir::dyn_cast<fir::ConvertOp>(reductionOperand.getDefiningOp())) {
-      if (convertOp.getOperand() == loadVal)
-        continue;
-      return convertOp;
-    }
-  }
-  return nullptr;
-}
-
-void Fortran::lower::updateReduction(mlir::Operation *op,
-                                     fir::FirOpBuilder &firOpBuilder,
-                                     mlir::Value loadVal,
-                                     mlir::Value reductionVal,
-                                     fir::ConvertOp *convertOp) {
-  mlir::OpBuilder::InsertPoint insertPtDel = firOpBuilder.saveInsertionPoint();
-  firOpBuilder.setInsertionPoint(op);
-
-  mlir::Value reductionOp;
-  if (convertOp)
-    reductionOp = convertOp->getOperand();
-  else if (op->getOperand(0) == loadVal)
-    reductionOp = op->getOperand(1);
-  else
-    reductionOp = op->getOperand(0);
-
-  firOpBuilder.create<mlir::omp::ReductionOp>(op->getLoc(), reductionOp,
-                                              reductionVal);
-  firOpBuilder.restoreInsertionPoint(insertPtDel);
-}
-
-void Fortran::lower::removeStoreOp(mlir::Operation *reductionOp,
-                                   mlir::Value symVal) {
-  for (mlir::Operation *reductionOpUse : reductionOp->getUsers()) {
-    if (auto convertReduction =
-            mlir::dyn_cast<fir::ConvertOp>(reductionOpUse)) {
-      for (mlir::Operation *convertReductionUse :
-           convertReduction.getRes().getUsers()) {
-        if (auto storeOp = mlir::dyn_cast<fir::StoreOp>(convertReductionUse)) {
-          if (storeOp.getMemref() == symVal)
-            storeOp.erase();
-        }
-        if (auto assignOp =
-                mlir::dyn_cast<hlfir::AssignOp>(convertReductionUse)) {
-          if (assignOp.getLhs() == symVal)
-            assignOp.erase();
-        }
-      }
-    }
-  }
-}
-
 bool Fortran::lower::isOpenMPTargetConstruct(
     const Fortran::parser::OpenMPConstruct &omp) {
   llvm::omp::Directive dir = llvm::omp::Directive::OMPD_unknown;



More information about the flang-commits mailing list