[flang-commits] [flang] [flang][OpenMP] Common lowering flow for atomic update (PR #69866)
via flang-commits
flang-commits at lists.llvm.org
Sat Oct 21 21:34:41 PDT 2023
https://github.com/NimishMishra created https://github.com/llvm/llvm-project/pull/69866
Offers a common lowering flow for scalar/non-scalar atomic variables. Fixes https://github.com/llvm/llvm-project/issues/68384
TODOs:
1. Lower intrinsic procedures
2. Use correct operations for AND, OR, EQV, NEQV
3. Discuss whether multiply/divide are signed or unsigned
>From 58baea0fa0e349e08c51a56b5f0ca95ec42a8d7f Mon Sep 17 00:00:00 2001
From: Nimish Mishra <neelam.nimish at gmail.com>
Date: Sun, 22 Oct 2023 09:59:09 +0530
Subject: [PATCH] [flang][OpenMP] Common lowering flow for atomic update
---
flang/lib/Lower/DirectivesCommon.h | 188 ++++++++----------
.../Lower/OpenMP/common-atomic-lowering.f90 | 74 +++++++
2 files changed, 160 insertions(+), 102 deletions(-)
create mode 100644 flang/test/Lower/OpenMP/common-atomic-lowering.f90
diff --git a/flang/lib/Lower/DirectivesCommon.h b/flang/lib/Lower/DirectivesCommon.h
index ed44598bc925212..56d79de2d31995d 100644
--- a/flang/lib/Lower/DirectivesCommon.h
+++ b/flang/lib/Lower/DirectivesCommon.h
@@ -204,76 +204,62 @@ static inline void genOmpAccAtomicUpdateStatement(
// Generate `omp.atomic.update` operation for atomic assignment statements
fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
mlir::Location currentLocation = converter.getCurrentLocation();
-
- const auto *varDesignator =
- std::get_if<Fortran::common::Indirection<Fortran::parser::Designator>>(
- &assignmentStmtVariable.u);
- assert(varDesignator && "Variable designator for atomic update assignment "
- "statement does not exist");
- const Fortran::parser::Name *name =
- Fortran::semantics::getDesignatorNameIfDataRef(varDesignator->value());
- if (!name)
- TODO(converter.getCurrentLocation(),
- "Array references as atomic update variable");
- assert(name && name->symbol &&
- "No symbol attached to atomic update variable");
- if (Fortran::semantics::IsAllocatableOrPointer(name->symbol->GetUltimate()))
- converter.bindSymbol(*name->symbol, lhsAddr);
-
- // Lowering is in two steps :
- // subroutine sb
- // integer :: a, b
- // !$omp atomic update
- // a = a + b
- // end subroutine
- //
- // 1. Lower to scf.execute_region_op
- //
- // func.func @_QPsb() {
- // %0 = fir.alloca i32 {bindc_name = "a", uniq_name = "_QFsbEa"}
- // %1 = fir.alloca i32 {bindc_name = "b", uniq_name = "_QFsbEb"}
- // %2 = scf.execute_region -> i32 {
- // %3 = fir.load %0 : !fir.ref<i32>
- // %4 = fir.load %1 : !fir.ref<i32>
- // %5 = arith.addi %3, %4 : i32
- // scf.yield %5 : i32
- // }
- // return
- // }
- auto tempOp =
- firOpBuilder.create<mlir::scf::ExecuteRegionOp>(currentLocation, varType);
- firOpBuilder.createBlock(&tempOp.getRegion());
- mlir::Block &block = tempOp.getRegion().back();
- firOpBuilder.setInsertionPointToEnd(&block);
Fortran::lower::StatementContext stmtCtx;
- mlir::Value rhsExpr = fir::getBase(converter.genExprValue(
- *Fortran::semantics::GetExpr(assignmentStmtExpr), stmtCtx));
- mlir::Value convertResult =
- firOpBuilder.createConvert(currentLocation, varType, rhsExpr);
- // Insert the terminator: YieldOp.
- firOpBuilder.create<mlir::scf::YieldOp>(currentLocation, convertResult);
- firOpBuilder.setInsertionPointToStart(&block);
-
- // 2. Create the omp.atomic.update Operation using the Operations in the
- // temporary scf.execute_region Operation.
- //
- // func.func @_QPsb() {
- // %0 = fir.alloca i32 {bindc_name = "a", uniq_name = "_QFsbEa"}
- // %1 = fir.alloca i32 {bindc_name = "b", uniq_name = "_QFsbEb"}
- // %2 = fir.load %1 : !fir.ref<i32>
- // omp.atomic.update %0 : !fir.ref<i32> {
- // ^bb0(%arg0: i32):
- // %3 = fir.load %1 : !fir.ref<i32>
- // %4 = arith.addi %arg0, %3 : i32
- // omp.yield(%3 : i32)
- // }
- // return
- // }
- mlir::Value updateVar = converter.getSymbolAddress(*name->symbol);
- if (auto decl = updateVar.getDefiningOp<hlfir::DeclareOp>())
- updateVar = decl.getBase();
-
- firOpBuilder.setInsertionPointAfter(tempOp);
+ mlir::Value convertRhs = nullptr;
+
+ auto lowerExpression = [&](const auto &intrinsicBinaryExpr) {
+ const auto &variableName{assignmentStmtVariable.GetSource().ToString()};
+ const auto &exprLeft{std::get<0>(intrinsicBinaryExpr.t)};
+ if (exprLeft.value().source.ToString() == variableName) {
+ // Update statement is of form `x = x op expr`
+ const auto &exprToLower{std::get<1>(intrinsicBinaryExpr.t)};
+ mlir::Value rhsExpr = fir::getBase(converter.genExprValue(
+ *Fortran::semantics::GetExpr(exprToLower), stmtCtx));
+ convertRhs =
+ firOpBuilder.createConvert(currentLocation, varType, rhsExpr);
+ } else {
+ // Update statement is of form `x = expr op x`
+ const auto &exprToLower{std::get<0>(intrinsicBinaryExpr.t)};
+ mlir::Value rhsExpr = fir::getBase(converter.genExprValue(
+ *Fortran::semantics::GetExpr(exprToLower), stmtCtx));
+ convertRhs =
+ firOpBuilder.createConvert(currentLocation, varType, rhsExpr);
+ }
+ };
+ Fortran::common::visit(
+ Fortran::common::visitors{
+ [&](const common::Indirection<parser::FunctionReference> &x) {
+ TODO(converter.getCurrentLocation(),
+ "Not yet implemented: intrinsic procedure in atomic update "
+ "expressions");
+ },
+ [&](const Fortran::parser::Expr::Add &intrinsicBinaryExpr) {
+ lowerExpression(intrinsicBinaryExpr);
+ },
+ [&](const Fortran::parser::Expr::Subtract &intrinsicBinaryExpr) {
+ lowerExpression(intrinsicBinaryExpr);
+ },
+ [&](const Fortran::parser::Expr::Multiply &intrinsicBinaryExpr) {
+ lowerExpression(intrinsicBinaryExpr);
+ },
+ [&](const Fortran::parser::Expr::Divide &intrinsicBinaryExpr) {
+ lowerExpression(intrinsicBinaryExpr);
+ },
+ [&](const Fortran::parser::Expr::AND &intrinsicBinaryExpr) {
+ lowerExpression(intrinsicBinaryExpr);
+ },
+ [&](const Fortran::parser::Expr::OR &intrinsicBinaryExpr) {
+ lowerExpression(intrinsicBinaryExpr);
+ },
+ [&](const Fortran::parser::Expr::EQV &intrinsicBinaryExpr) {
+ lowerExpression(intrinsicBinaryExpr);
+ },
+ [&](const Fortran::parser::Expr::NEQV &intrinsicBinaryExpr) {
+ lowerExpression(intrinsicBinaryExpr);
+ },
+ [&](const auto &) {},
+ },
+ assignmentStmtExpr.u);
mlir::Operation *atomicUpdateOp = nullptr;
if constexpr (std::is_same<AtomicListT,
@@ -289,10 +275,10 @@ static inline void genOmpAccAtomicUpdateStatement(
genOmpAtomicHintAndMemoryOrderClauses(converter, *rightHandClauseList,
hint, memoryOrder);
atomicUpdateOp = firOpBuilder.create<mlir::omp::AtomicUpdateOp>(
- currentLocation, updateVar, hint, memoryOrder);
+ currentLocation, lhsAddr, hint, memoryOrder);
} else {
atomicUpdateOp = firOpBuilder.create<mlir::acc::AtomicUpdateOp>(
- currentLocation, updateVar);
+ currentLocation, lhsAddr);
}
llvm::SmallVector<mlir::Type> varTys = {varType};
@@ -301,38 +287,36 @@ static inline void genOmpAccAtomicUpdateStatement(
mlir::Value val =
fir::getBase(atomicUpdateOp->getRegion(0).front().getArgument(0));
- llvm::SmallVector<mlir::Operation *> ops;
- for (mlir::Operation &op : tempOp.getRegion().getOps())
- ops.push_back(&op);
-
- // SCF Yield is converted to OMP Yield. All other operations are copied
- for (mlir::Operation *op : ops) {
- if (auto y = mlir::dyn_cast<mlir::scf::YieldOp>(op)) {
- firOpBuilder.setInsertionPointToEnd(
- &atomicUpdateOp->getRegion(0).front());
- if constexpr (std::is_same<AtomicListT,
- Fortran::parser::OmpAtomicClauseList>()) {
- firOpBuilder.create<mlir::omp::YieldOp>(currentLocation,
- y.getResults());
- } else {
- firOpBuilder.create<mlir::acc::YieldOp>(currentLocation,
- y.getResults());
- }
- op->erase();
- } else {
- op->remove();
- atomicUpdateOp->getRegion(0).front().push_back(op);
- }
- }
-
- // Remove the load and replace all uses of load with the block argument
- for (mlir::Operation &op : atomicUpdateOp->getRegion(0).getOps()) {
- fir::LoadOp y = mlir::dyn_cast<fir::LoadOp>(&op);
- if (y && y.getMemref() == updateVar)
- y.getRes().replaceAllUsesWith(val);
+ mlir::Value op = nullptr;
+ if (std::get_if<Fortran::parser::Expr::Add>(&assignmentStmtExpr.u)) {
+ op = firOpBuilder.create<mlir::arith::AddIOp>(currentLocation, val,
+ convertRhs);
+ } else if (std::get_if<Fortran::parser::Expr::Subtract>(
+ &assignmentStmtExpr.u)) {
+ op = firOpBuilder.create<mlir::arith::SubIOp>(currentLocation, val,
+ convertRhs);
+ } else if (std::get_if<Fortran::parser::Expr::Multiply>(
+ &assignmentStmtExpr.u)) {
+ op = firOpBuilder.create<mlir::arith::MulIOp>(currentLocation, val,
+ convertRhs);
+ } else if (std::get_if<Fortran::parser::Expr::Divide>(
+ &assignmentStmtExpr.u)) {
+ op = firOpBuilder.create<mlir::arith::DivUIOp>(currentLocation, val,
+ convertRhs);
+ } else if (std::get_if<Fortran::parser::Expr::AND>(&assignmentStmtExpr.u)) {
+ op = firOpBuilder.create<mlir::arith::AndIOp>(currentLocation, val,
+ convertRhs);
+ } else if (std::get_if<Fortran::parser::Expr::OR>(&assignmentStmtExpr.u)) {
+ op = firOpBuilder.create<mlir::arith::OrIOp>(currentLocation, val,
+ convertRhs);
+ } else if (std::get_if<Fortran::parser::Expr::EQV>(&assignmentStmtExpr.u)) {
+ op = firOpBuilder.create<mlir::arith::CmpIOp>(
+ currentLocation, mlir::arith::CmpIPredicate::eq, val, convertRhs);
+ } else if (std::get_if<Fortran::parser::Expr::NEQV>(&assignmentStmtExpr.u)) {
+ op = firOpBuilder.create<mlir::arith::CmpIOp>(
+ currentLocation, mlir::arith::CmpIPredicate::ne, val, convertRhs);
}
-
- tempOp.erase();
+ firOpBuilder.create<mlir::omp::YieldOp>(currentLocation, op);
}
/// Processes an atomic construct with write clause.
diff --git a/flang/test/Lower/OpenMP/common-atomic-lowering.f90 b/flang/test/Lower/OpenMP/common-atomic-lowering.f90
new file mode 100644
index 000000000000000..7da30243e676c00
--- /dev/null
+++ b/flang/test/Lower/OpenMP/common-atomic-lowering.f90
@@ -0,0 +1,74 @@
+!RUN: %flang_fc1 -emit-hlfir -fopenmp %s -o - | FileCheck %s
+
+!CHECK: func.func @_QQmain() attributes {fir.bindc_name = "sample"} {
+!CHECK: %[[val_0:.*]] = fir.alloca i32 {bindc_name = "a", uniq_name = "_QFEa"}
+!CHECK: %[[val_1:.*]]:2 = hlfir.declare %[[val_0]] {uniq_name = "_QFEa"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
+!CHECK: %[[val_2:.*]] = fir.alloca i32 {bindc_name = "b", uniq_name = "_QFEb"}
+!CHECK: %[[val_3:.*]]:2 = hlfir.declare %[[val_2]] {uniq_name = "_QFEb"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
+!CHECK: %[[val_4:.*]] = fir.alloca i32 {bindc_name = "x", uniq_name = "_QFEx"}
+!CHECK: %[[val_5:.*]]:2 = hlfir.declare %[[val_4]] {uniq_name = "_QFEx"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
+!CHECK: %[[val_c5:.*]] = arith.constant 5 : index
+!CHECK: %[[val_6:.*]] = fir.alloca !fir.array<5xi32> {bindc_name = "y", uniq_name = "_QFEy"}
+!CHECK: %[[val_7:.*]] = fir.shape %[[val_c5]] : (index) -> !fir.shape<1>
+!CHECK: %[[val_8:.*]]:2 = hlfir.declare %[[val_6]](%[[val_7]]) {uniq_name = "_QFEy"} : (!fir.ref<!fir.array<5xi32>>, !fir.shape<1>) -> (!fir.ref<!fir.array<5xi32>>, !fir.ref<!fir.array<5xi32>>)
+!CHECK: %[[val_c2:.*]] = arith.constant 2 : index
+!CHECK: %[[val_9:.*]] = hlfir.designate %[[val_8]]#0 (%[[val_c2]]) : (!fir.ref<!fir.array<5xi32>>, index) -> !fir.ref<i32>
+!CHECK: %[[val_c8:.*]] = arith.constant 8 : i32
+!CHECK: %[[val_10:.*]] = fir.load %[[val_5]]#0 : !fir.ref<i32>
+!CHECK: %[[val_11:.*]] = arith.addi %[[val_c8]], %[[val_10]] : i32
+!CHECK: %[[val_12:.*]] = hlfir.no_reassoc %[[val_11]] : i32
+!CHECK: omp.atomic.update %[[val_9]] : !fir.ref<i32> {
+!CHECK: ^bb0(%[[ARG:.*]]: i32):
+!CHECK: %[[val_18:.*]] = arith.muli %[[ARG]], %[[val_12]] : i32
+!CHECK: omp.yield(%[[val_18]] : i32)
+!CHECK: }
+!CHECK: %[[val_c2_0:.*]] = arith.constant 2 : index
+!CHECK: %[[val_13:.*]] = hlfir.designate %[[val_8]]#0 (%[[val_c2_0]]) : (!fir.ref<!fir.array<5xi32>>, index) -> !fir.ref<i32>
+!CHECK: %[[val_c8_1:.*]] = arith.constant 8 : i32
+!CHECK: omp.atomic.update %[[val_13:.*]] : !fir.ref<i32> {
+!CHECK: ^bb0(%[[ARG:.*]]: i32):
+!CHECK: %[[val_18:.*]] = arith.divui %[[ARG]], %[[val_c8_1]] : i32
+!CHECK: omp.yield(%[[val_18]] : i32)
+!CHECK: }
+!CHECK: %[[val_c8_2:.*]] = arith.constant 8 : i32
+!CHECK: %[[val_c4:.*]] = arith.constant 4 : index
+!CHECK: %[[val_14:.*]] = hlfir.designate %[[val_8]]#0 (%[[val_c4]]) : (!fir.ref<!fir.array<5xi32>>, index) -> !fir.ref<i32>
+!CHECK: %[[val_15:.*]] = fir.load %[[val_14]] : !fir.ref<i32>
+!CHECK: %[[val_16:.*]] = arith.addi %[[val_c8_2]], %[[val_15]] : i32
+!CHECK: %[[val_17:.*]] = hlfir.no_reassoc %[[val_16]] : i32
+!CHECK: omp.atomic.update %[[val_5]]#1 : !fir.ref<i32> {
+!CHECK: ^bb0(%[[ARG:.*]]: i32):
+!CHECK: %[[val_18:.*]] = arith.addi %[[ARG]], %[[val_17]] : i32
+!CHECK: omp.yield(%[[val_18]] : i32)
+!CHECK: }
+!CHECK: %[[val_c8_3:.*]] = arith.constant 8 : i32
+!CHECK: omp.atomic.update %[[val_5]]#1 : !fir.ref<i32> {
+!CHECK: ^bb0(%[[ARG]]: i32):
+!CHECK: %[[val_18:.*]] = arith.subi %[[ARG]], %[[val_c8_3]] : i32
+!CHECK: omp.yield(%[[val_18]] : i32)
+!CHECK: }
+!CHECK: return
+!CHECK: }
+program sample
+
+ integer :: x
+ integer, dimension(5) :: y
+ integer :: a, b
+
+ !$omp atomic update
+ y(2) = (8 + x) * y(2)
+ !$omp end atomic
+
+ !$omp atomic update
+ y(2) = y(2) / 8
+ !$omp end atomic
+
+ !$omp atomic update
+ x = (8 + y(4)) + x
+ !$omp end atomic
+
+ !$omp atomic update
+ x = 8 - x
+ !$omp end atomic
+
+end program sample
More information about the flang-commits
mailing list