[flang-commits] [flang] [Flang][OpenMP] NFC: Remove old reduction lowering code (PR #88798)

Kiran Chandramohan via flang-commits flang-commits at lists.llvm.org
Mon Apr 15 14:06:46 PDT 2024


https://github.com/kiranchandramohan updated https://github.com/llvm/llvm-project/pull/88798

>From 998f5089d60539c9447da23b597b7de68aa998e9 Mon Sep 17 00:00:00 2001
From: Kiran Chandramohan <kiran.chandramohan at arm.com>
Date: Mon, 15 Apr 2024 21:01:25 +0000
Subject: [PATCH] [Flang][OpenMP] NFC: Remove old reduction lowering code

---
 flang/lib/Lower/OpenMP/OpenMP.cpp | 223 ------------------------------
 1 file changed, 223 deletions(-)

diff --git a/flang/lib/Lower/OpenMP/OpenMP.cpp b/flang/lib/Lower/OpenMP/OpenMP.cpp
index 340921c867246c..b67ee59d463eb3 100644
--- a/flang/lib/Lower/OpenMP/OpenMP.cpp
+++ b/flang/lib/Lower/OpenMP/OpenMP.cpp
@@ -103,21 +103,6 @@ static fir::GlobalOp globalInitialization(
   return global;
 }
 
-static mlir::Operation *getCompareFromReductionOp(mlir::Operation *reductionOp,
-                                                  mlir::Value loadVal) {
-  for (mlir::Value reductionOperand : reductionOp->getOperands()) {
-    if (mlir::Operation *compareOp = reductionOperand.getDefiningOp()) {
-      if (compareOp->getOperand(0) == loadVal ||
-          compareOp->getOperand(1) == loadVal)
-        assert((mlir::isa<mlir::arith::CmpIOp>(compareOp) ||
-                mlir::isa<mlir::arith::CmpFOp>(compareOp)) &&
-               "Expected comparison not found in reduction intrinsic");
-      return compareOp;
-    }
-  }
-  return nullptr;
-}
-
 // Get the extended value for \p val by extracting additional variable
 // information from \p base.
 static fir::ExtendedValue getExtendedValue(fir::ExtendedValue base,
@@ -237,213 +222,6 @@ createAndSetPrivatizedLoopVar(Fortran::lower::AbstractConverter &converter,
   return storeOp;
 }
 
-static mlir::Operation *
-findReductionChain(mlir::Value loadVal, mlir::Value *reductionVal = nullptr) {
-  for (mlir::OpOperand &loadOperand : loadVal.getUses()) {
-    if (mlir::Operation *reductionOp = loadOperand.getOwner()) {
-      if (auto convertOp = mlir::dyn_cast<fir::ConvertOp>(reductionOp)) {
-        for (mlir::OpOperand &convertOperand : convertOp.getRes().getUses()) {
-          if (mlir::Operation *reductionOp = convertOperand.getOwner())
-            return reductionOp;
-        }
-      }
-      for (mlir::OpOperand &reductionOperand : reductionOp->getUses()) {
-        if (auto store =
-                mlir::dyn_cast<fir::StoreOp>(reductionOperand.getOwner())) {
-          if (store.getMemref() == *reductionVal) {
-            store.erase();
-            return reductionOp;
-          }
-        }
-        if (auto assign =
-                mlir::dyn_cast<hlfir::AssignOp>(reductionOperand.getOwner())) {
-          if (assign.getLhs() == *reductionVal) {
-            assign.erase();
-            return reductionOp;
-          }
-        }
-      }
-    }
-  }
-  return nullptr;
-}
-
-// for a logical operator 'op' reduction X = X op Y
-// This function returns the operation responsible for converting Y from
-// fir.logical<4> to i1
-static fir::ConvertOp getConvertFromReductionOp(mlir::Operation *reductionOp,
-                                                mlir::Value loadVal) {
-  for (mlir::Value reductionOperand : reductionOp->getOperands()) {
-    if (auto convertOp =
-            mlir::dyn_cast<fir::ConvertOp>(reductionOperand.getDefiningOp())) {
-      if (convertOp.getOperand() == loadVal)
-        continue;
-      return convertOp;
-    }
-  }
-  return nullptr;
-}
-
-static void updateReduction(mlir::Operation *op,
-                            fir::FirOpBuilder &firOpBuilder,
-                            mlir::Value loadVal, mlir::Value reductionVal,
-                            fir::ConvertOp *convertOp = nullptr) {
-  mlir::OpBuilder::InsertPoint insertPtDel = firOpBuilder.saveInsertionPoint();
-  firOpBuilder.setInsertionPoint(op);
-
-  mlir::Value reductionOp;
-  if (convertOp)
-    reductionOp = convertOp->getOperand();
-  else if (op->getOperand(0) == loadVal)
-    reductionOp = op->getOperand(1);
-  else
-    reductionOp = op->getOperand(0);
-
-  firOpBuilder.create<mlir::omp::ReductionOp>(op->getLoc(), reductionOp,
-                                              reductionVal);
-  firOpBuilder.restoreInsertionPoint(insertPtDel);
-}
-
-static void removeStoreOp(mlir::Operation *reductionOp, mlir::Value symVal) {
-  for (mlir::Operation *reductionOpUse : reductionOp->getUsers()) {
-    if (auto convertReduction =
-            mlir::dyn_cast<fir::ConvertOp>(reductionOpUse)) {
-      for (mlir::Operation *convertReductionUse :
-           convertReduction.getRes().getUsers()) {
-        if (auto storeOp = mlir::dyn_cast<fir::StoreOp>(convertReductionUse)) {
-          if (storeOp.getMemref() == symVal)
-            storeOp.erase();
-        }
-        if (auto assignOp =
-                mlir::dyn_cast<hlfir::AssignOp>(convertReductionUse)) {
-          if (assignOp.getLhs() == symVal)
-            assignOp.erase();
-        }
-      }
-    }
-  }
-}
-
-// Generate an OpenMP reduction operation.
-// TODO: Currently assumes it is either an integer addition/multiplication
-// reduction, or a logical and reduction. Generalize this for various reduction
-// operation types.
-// TODO: Generate the reduction operation during lowering instead of creating
-// and removing operations since this is not a robust approach. Also, removing
-// ops in the builder (instead of a rewriter) is probably not the best approach.
-static void
-genOpenMPReduction(Fortran::lower::AbstractConverter &converter,
-                   Fortran::semantics::SemanticsContext &semaCtx,
-                   const Fortran::parser::OmpClauseList &clauseList) {
-  fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
-
-  List<Clause> clauses{makeClauses(clauseList, semaCtx)};
-
-  for (const Clause &clause : clauses) {
-    if (const auto &reductionClause =
-            std::get_if<clause::Reduction>(&clause.u)) {
-      const auto &redOperatorList{
-          std::get<clause::Reduction::ReductionIdentifiers>(
-              reductionClause->t)};
-      assert(redOperatorList.size() == 1 && "Expecting single operator");
-      const auto &redOperator = redOperatorList.front();
-      const auto &objects{std::get<ObjectList>(reductionClause->t)};
-      if (const auto *reductionOp =
-              std::get_if<clause::DefinedOperator>(&redOperator.u)) {
-        const auto &intrinsicOp{
-            std::get<clause::DefinedOperator::IntrinsicOperator>(
-                reductionOp->u)};
-
-        switch (intrinsicOp) {
-        case clause::DefinedOperator::IntrinsicOperator::Add:
-        case clause::DefinedOperator::IntrinsicOperator::Multiply:
-        case clause::DefinedOperator::IntrinsicOperator::AND:
-        case clause::DefinedOperator::IntrinsicOperator::EQV:
-        case clause::DefinedOperator::IntrinsicOperator::OR:
-        case clause::DefinedOperator::IntrinsicOperator::NEQV:
-          break;
-        default:
-          continue;
-        }
-        for (const Object &object : objects) {
-          if (const Fortran::semantics::Symbol *symbol = object.id()) {
-            mlir::Value reductionVal = converter.getSymbolAddress(*symbol);
-            if (auto declOp = reductionVal.getDefiningOp<hlfir::DeclareOp>())
-              reductionVal = declOp.getBase();
-            mlir::Type reductionType =
-                reductionVal.getType().cast<fir::ReferenceType>().getEleTy();
-            if (!reductionType.isa<fir::LogicalType>()) {
-              if (!reductionType.isIntOrIndexOrFloat())
-                continue;
-            }
-            for (mlir::OpOperand &reductionValUse : reductionVal.getUses()) {
-              if (auto loadOp =
-                      mlir::dyn_cast<fir::LoadOp>(reductionValUse.getOwner())) {
-                mlir::Value loadVal = loadOp.getRes();
-                if (reductionType.isa<fir::LogicalType>()) {
-                  mlir::Operation *reductionOp = findReductionChain(loadVal);
-                  fir::ConvertOp convertOp =
-                      getConvertFromReductionOp(reductionOp, loadVal);
-                  updateReduction(reductionOp, firOpBuilder, loadVal,
-                                  reductionVal, &convertOp);
-                  removeStoreOp(reductionOp, reductionVal);
-                } else if (mlir::Operation *reductionOp =
-                               findReductionChain(loadVal, &reductionVal)) {
-                  updateReduction(reductionOp, firOpBuilder, loadVal,
-                                  reductionVal);
-                }
-              }
-            }
-          }
-        }
-      } else if (const auto *reductionIntrinsic =
-                     std::get_if<clause::ProcedureDesignator>(&redOperator.u)) {
-        if (!ReductionProcessor::supportedIntrinsicProcReduction(
-                *reductionIntrinsic))
-          continue;
-        ReductionProcessor::ReductionIdentifier redId =
-            ReductionProcessor::getReductionType(*reductionIntrinsic);
-        for (const Object &object : objects) {
-          if (const Fortran::semantics::Symbol *symbol = object.id()) {
-            mlir::Value reductionVal = converter.getSymbolAddress(*symbol);
-            if (auto declOp = reductionVal.getDefiningOp<hlfir::DeclareOp>())
-              reductionVal = declOp.getBase();
-            for (const mlir::OpOperand &reductionValUse :
-                 reductionVal.getUses()) {
-              if (auto loadOp =
-                      mlir::dyn_cast<fir::LoadOp>(reductionValUse.getOwner())) {
-                mlir::Value loadVal = loadOp.getRes();
-                // Max is lowered as a compare -> select.
-                // Match the pattern here.
-                mlir::Operation *reductionOp =
-                    findReductionChain(loadVal, &reductionVal);
-                if (reductionOp == nullptr)
-                  continue;
-
-                if (redId == ReductionProcessor::ReductionIdentifier::MAX ||
-                    redId == ReductionProcessor::ReductionIdentifier::MIN) {
-                  assert(mlir::isa<mlir::arith::SelectOp>(reductionOp) &&
-                         "Selection Op not found in reduction intrinsic");
-                  mlir::Operation *compareOp =
-                      getCompareFromReductionOp(reductionOp, loadVal);
-                  updateReduction(compareOp, firOpBuilder, loadVal,
-                                  reductionVal);
-                }
-                if (redId == ReductionProcessor::ReductionIdentifier::IOR ||
-                    redId == ReductionProcessor::ReductionIdentifier::IEOR ||
-                    redId == ReductionProcessor::ReductionIdentifier::IAND) {
-                  updateReduction(reductionOp, firOpBuilder, loadVal,
-                                  reductionVal);
-                }
-              }
-            }
-          }
-        }
-      }
-    }
-  }
-}
-
 struct OpWithBodyGenInfo {
   /// A type for a code-gen callback function. This takes as argument the op for
   /// which the code is being generated and returns the arguments of the op's
@@ -2287,7 +2065,6 @@ static void genOMP(Fortran::lower::AbstractConverter &converter,
     // 2.9.3.1 SIMD construct
     createSimdLoop(converter, semaCtx, eval, ompDirective, loopOpClauseList,
                    currentLocation);
-    genOpenMPReduction(converter, semaCtx, loopOpClauseList);
   } else {
     createWsloop(converter, semaCtx, eval, ompDirective, loopOpClauseList,
                  endClauseList, currentLocation);



More information about the flang-commits mailing list