[flang-commits] [flang] [Flang][OpenMP] NFC: Remove old reduction lowering code (PR #88798)
Kiran Chandramohan via flang-commits
flang-commits at lists.llvm.org
Mon Apr 15 14:06:46 PDT 2024
https://github.com/kiranchandramohan updated https://github.com/llvm/llvm-project/pull/88798
>From 998f5089d60539c9447da23b597b7de68aa998e9 Mon Sep 17 00:00:00 2001
From: Kiran Chandramohan <kiran.chandramohan at arm.com>
Date: Mon, 15 Apr 2024 21:01:25 +0000
Subject: [PATCH] [Flang][OpenMP] NFC: Remove old reduction lowering code
---
flang/lib/Lower/OpenMP/OpenMP.cpp | 223 ------------------------------
1 file changed, 223 deletions(-)
diff --git a/flang/lib/Lower/OpenMP/OpenMP.cpp b/flang/lib/Lower/OpenMP/OpenMP.cpp
index 340921c867246c..b67ee59d463eb3 100644
--- a/flang/lib/Lower/OpenMP/OpenMP.cpp
+++ b/flang/lib/Lower/OpenMP/OpenMP.cpp
@@ -103,21 +103,6 @@ static fir::GlobalOp globalInitialization(
return global;
}
-static mlir::Operation *getCompareFromReductionOp(mlir::Operation *reductionOp,
- mlir::Value loadVal) {
- for (mlir::Value reductionOperand : reductionOp->getOperands()) {
- if (mlir::Operation *compareOp = reductionOperand.getDefiningOp()) {
- if (compareOp->getOperand(0) == loadVal ||
- compareOp->getOperand(1) == loadVal)
- assert((mlir::isa<mlir::arith::CmpIOp>(compareOp) ||
- mlir::isa<mlir::arith::CmpFOp>(compareOp)) &&
- "Expected comparison not found in reduction intrinsic");
- return compareOp;
- }
- }
- return nullptr;
-}
-
// Get the extended value for \p val by extracting additional variable
// information from \p base.
static fir::ExtendedValue getExtendedValue(fir::ExtendedValue base,
@@ -237,213 +222,6 @@ createAndSetPrivatizedLoopVar(Fortran::lower::AbstractConverter &converter,
return storeOp;
}
-static mlir::Operation *
-findReductionChain(mlir::Value loadVal, mlir::Value *reductionVal = nullptr) {
- for (mlir::OpOperand &loadOperand : loadVal.getUses()) {
- if (mlir::Operation *reductionOp = loadOperand.getOwner()) {
- if (auto convertOp = mlir::dyn_cast<fir::ConvertOp>(reductionOp)) {
- for (mlir::OpOperand &convertOperand : convertOp.getRes().getUses()) {
- if (mlir::Operation *reductionOp = convertOperand.getOwner())
- return reductionOp;
- }
- }
- for (mlir::OpOperand &reductionOperand : reductionOp->getUses()) {
- if (auto store =
- mlir::dyn_cast<fir::StoreOp>(reductionOperand.getOwner())) {
- if (store.getMemref() == *reductionVal) {
- store.erase();
- return reductionOp;
- }
- }
- if (auto assign =
- mlir::dyn_cast<hlfir::AssignOp>(reductionOperand.getOwner())) {
- if (assign.getLhs() == *reductionVal) {
- assign.erase();
- return reductionOp;
- }
- }
- }
- }
- }
- return nullptr;
-}
-
-// for a logical operator 'op' reduction X = X op Y
-// This function returns the operation responsible for converting Y from
-// fir.logical<4> to i1
-static fir::ConvertOp getConvertFromReductionOp(mlir::Operation *reductionOp,
- mlir::Value loadVal) {
- for (mlir::Value reductionOperand : reductionOp->getOperands()) {
- if (auto convertOp =
- mlir::dyn_cast<fir::ConvertOp>(reductionOperand.getDefiningOp())) {
- if (convertOp.getOperand() == loadVal)
- continue;
- return convertOp;
- }
- }
- return nullptr;
-}
-
-static void updateReduction(mlir::Operation *op,
- fir::FirOpBuilder &firOpBuilder,
- mlir::Value loadVal, mlir::Value reductionVal,
- fir::ConvertOp *convertOp = nullptr) {
- mlir::OpBuilder::InsertPoint insertPtDel = firOpBuilder.saveInsertionPoint();
- firOpBuilder.setInsertionPoint(op);
-
- mlir::Value reductionOp;
- if (convertOp)
- reductionOp = convertOp->getOperand();
- else if (op->getOperand(0) == loadVal)
- reductionOp = op->getOperand(1);
- else
- reductionOp = op->getOperand(0);
-
- firOpBuilder.create<mlir::omp::ReductionOp>(op->getLoc(), reductionOp,
- reductionVal);
- firOpBuilder.restoreInsertionPoint(insertPtDel);
-}
-
-static void removeStoreOp(mlir::Operation *reductionOp, mlir::Value symVal) {
- for (mlir::Operation *reductionOpUse : reductionOp->getUsers()) {
- if (auto convertReduction =
- mlir::dyn_cast<fir::ConvertOp>(reductionOpUse)) {
- for (mlir::Operation *convertReductionUse :
- convertReduction.getRes().getUsers()) {
- if (auto storeOp = mlir::dyn_cast<fir::StoreOp>(convertReductionUse)) {
- if (storeOp.getMemref() == symVal)
- storeOp.erase();
- }
- if (auto assignOp =
- mlir::dyn_cast<hlfir::AssignOp>(convertReductionUse)) {
- if (assignOp.getLhs() == symVal)
- assignOp.erase();
- }
- }
- }
- }
-}
-
-// Generate an OpenMP reduction operation.
-// TODO: Currently assumes it is either an integer addition/multiplication
-// reduction, or a logical and reduction. Generalize this for various reduction
-// operation types.
-// TODO: Generate the reduction operation during lowering instead of creating
-// and removing operations since this is not a robust approach. Also, removing
-// ops in the builder (instead of a rewriter) is probably not the best approach.
-static void
-genOpenMPReduction(Fortran::lower::AbstractConverter &converter,
- Fortran::semantics::SemanticsContext &semaCtx,
- const Fortran::parser::OmpClauseList &clauseList) {
- fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
-
- List<Clause> clauses{makeClauses(clauseList, semaCtx)};
-
- for (const Clause &clause : clauses) {
- if (const auto &reductionClause =
- std::get_if<clause::Reduction>(&clause.u)) {
- const auto &redOperatorList{
- std::get<clause::Reduction::ReductionIdentifiers>(
- reductionClause->t)};
- assert(redOperatorList.size() == 1 && "Expecting single operator");
- const auto &redOperator = redOperatorList.front();
- const auto &objects{std::get<ObjectList>(reductionClause->t)};
- if (const auto *reductionOp =
- std::get_if<clause::DefinedOperator>(&redOperator.u)) {
- const auto &intrinsicOp{
- std::get<clause::DefinedOperator::IntrinsicOperator>(
- reductionOp->u)};
-
- switch (intrinsicOp) {
- case clause::DefinedOperator::IntrinsicOperator::Add:
- case clause::DefinedOperator::IntrinsicOperator::Multiply:
- case clause::DefinedOperator::IntrinsicOperator::AND:
- case clause::DefinedOperator::IntrinsicOperator::EQV:
- case clause::DefinedOperator::IntrinsicOperator::OR:
- case clause::DefinedOperator::IntrinsicOperator::NEQV:
- break;
- default:
- continue;
- }
- for (const Object &object : objects) {
- if (const Fortran::semantics::Symbol *symbol = object.id()) {
- mlir::Value reductionVal = converter.getSymbolAddress(*symbol);
- if (auto declOp = reductionVal.getDefiningOp<hlfir::DeclareOp>())
- reductionVal = declOp.getBase();
- mlir::Type reductionType =
- reductionVal.getType().cast<fir::ReferenceType>().getEleTy();
- if (!reductionType.isa<fir::LogicalType>()) {
- if (!reductionType.isIntOrIndexOrFloat())
- continue;
- }
- for (mlir::OpOperand &reductionValUse : reductionVal.getUses()) {
- if (auto loadOp =
- mlir::dyn_cast<fir::LoadOp>(reductionValUse.getOwner())) {
- mlir::Value loadVal = loadOp.getRes();
- if (reductionType.isa<fir::LogicalType>()) {
- mlir::Operation *reductionOp = findReductionChain(loadVal);
- fir::ConvertOp convertOp =
- getConvertFromReductionOp(reductionOp, loadVal);
- updateReduction(reductionOp, firOpBuilder, loadVal,
- reductionVal, &convertOp);
- removeStoreOp(reductionOp, reductionVal);
- } else if (mlir::Operation *reductionOp =
- findReductionChain(loadVal, &reductionVal)) {
- updateReduction(reductionOp, firOpBuilder, loadVal,
- reductionVal);
- }
- }
- }
- }
- }
- } else if (const auto *reductionIntrinsic =
- std::get_if<clause::ProcedureDesignator>(&redOperator.u)) {
- if (!ReductionProcessor::supportedIntrinsicProcReduction(
- *reductionIntrinsic))
- continue;
- ReductionProcessor::ReductionIdentifier redId =
- ReductionProcessor::getReductionType(*reductionIntrinsic);
- for (const Object &object : objects) {
- if (const Fortran::semantics::Symbol *symbol = object.id()) {
- mlir::Value reductionVal = converter.getSymbolAddress(*symbol);
- if (auto declOp = reductionVal.getDefiningOp<hlfir::DeclareOp>())
- reductionVal = declOp.getBase();
- for (const mlir::OpOperand &reductionValUse :
- reductionVal.getUses()) {
- if (auto loadOp =
- mlir::dyn_cast<fir::LoadOp>(reductionValUse.getOwner())) {
- mlir::Value loadVal = loadOp.getRes();
- // Max is lowered as a compare -> select.
- // Match the pattern here.
- mlir::Operation *reductionOp =
- findReductionChain(loadVal, &reductionVal);
- if (reductionOp == nullptr)
- continue;
-
- if (redId == ReductionProcessor::ReductionIdentifier::MAX ||
- redId == ReductionProcessor::ReductionIdentifier::MIN) {
- assert(mlir::isa<mlir::arith::SelectOp>(reductionOp) &&
- "Selection Op not found in reduction intrinsic");
- mlir::Operation *compareOp =
- getCompareFromReductionOp(reductionOp, loadVal);
- updateReduction(compareOp, firOpBuilder, loadVal,
- reductionVal);
- }
- if (redId == ReductionProcessor::ReductionIdentifier::IOR ||
- redId == ReductionProcessor::ReductionIdentifier::IEOR ||
- redId == ReductionProcessor::ReductionIdentifier::IAND) {
- updateReduction(reductionOp, firOpBuilder, loadVal,
- reductionVal);
- }
- }
- }
- }
- }
- }
- }
- }
-}
-
struct OpWithBodyGenInfo {
/// A type for a code-gen callback function. This takes as argument the op for
/// which the code is being generated and returns the arguments of the op's
@@ -2287,7 +2065,6 @@ static void genOMP(Fortran::lower::AbstractConverter &converter,
// 2.9.3.1 SIMD construct
createSimdLoop(converter, semaCtx, eval, ompDirective, loopOpClauseList,
currentLocation);
- genOpenMPReduction(converter, semaCtx, loopOpClauseList);
} else {
createWsloop(converter, semaCtx, eval, ompDirective, loopOpClauseList,
endClauseList, currentLocation);
More information about the flang-commits mailing list