[flang-commits] [flang] b6b0756 - [flang] Allow lowering of sub-expressions to be overridden (#69944)
via flang-commits
flang-commits at lists.llvm.org
Wed Oct 25 00:22:28 PDT 2023
Author: jeanPerier
Date: 2023-10-25T09:22:23+02:00
New Revision: b6b0756ce5c4e2e07d7f6f1f430d3d29afe9a8a8
URL: https://github.com/llvm/llvm-project/commit/b6b0756ce5c4e2e07d7f6f1f430d3d29afe9a8a8
DIFF: https://github.com/llvm/llvm-project/commit/b6b0756ce5c4e2e07d7f6f1f430d3d29afe9a8a8.diff
LOG: [flang] Allow lowering of sub-expressions to be overridden (#69944)
OpenACC/OpenMP atomic lowering needs finer control over expression
lowering. This patch allows mapping evaluate::Expr<T> to mlir::Value so
that any subsequent expression lowering uses the mapped value whenever an
operand is a mapped Expr<T>.
This is an alternative to
https://github.com/llvm/llvm-project/pull/69866, from which I took the
test and some of the logic used to extract the non-atomic sub-expression.
---------
Co-authored-by: Nimish Mishra <neelam.nimish at gmail.com>
Added:
flang/test/Lower/OpenMP/common-atomic-lowering.f90
Modified:
flang/include/flang/Lower/AbstractConverter.h
flang/lib/Lower/Bridge.cpp
flang/lib/Lower/ConvertExpr.cpp
flang/lib/Lower/ConvertExprToHLFIR.cpp
flang/lib/Lower/DirectivesCommon.h
flang/test/Lower/OpenACC/acc-atomic-capture.f90
flang/test/Lower/OpenACC/acc-atomic-update-hlfir.f90
flang/test/Lower/OpenACC/acc-atomic-update.f90
flang/test/Lower/OpenMP/FIR/atomic-capture.f90
flang/test/Lower/OpenMP/FIR/atomic-update.f90
flang/test/Lower/OpenMP/atomic-update-hlfir.f90
Removed:
################################################################################
diff --git a/flang/include/flang/Lower/AbstractConverter.h b/flang/include/flang/Lower/AbstractConverter.h
index c792e75f1146499..fa67729fe036684 100644
--- a/flang/include/flang/Lower/AbstractConverter.h
+++ b/flang/include/flang/Lower/AbstractConverter.h
@@ -60,6 +60,8 @@ using SomeExpr = Fortran::evaluate::Expr<Fortran::evaluate::SomeType>;
using SymbolRef = Fortran::common::Reference<const Fortran::semantics::Symbol>;
class StatementContext;
+using ExprToValueMap = llvm::DenseMap<const SomeExpr *, mlir::Value>;
+
//===----------------------------------------------------------------------===//
// AbstractConverter interface
//===----------------------------------------------------------------------===//
@@ -90,6 +92,14 @@ class AbstractConverter {
/// added or replaced at the inner-most level of the local symbol map.
virtual void bindSymbol(SymbolRef sym, const fir::ExtendedValue &exval) = 0;
+ /// Override lowering of expression with pre-lowered values.
+ /// Associate mlir::Value to evaluate::Expr. All subsequent call to
+ /// genExprXXX() will replace any occurrence of an overridden
+ /// expression in the expression tree by the pre-lowered values.
+ virtual void overrideExprValues(const ExprToValueMap *) = 0;
+ void resetExprOverrides() { overrideExprValues(nullptr); }
+ virtual const ExprToValueMap *getExprOverrides() = 0;
+
/// Get the label set associated with a symbol.
virtual bool lookupLabelSet(SymbolRef sym, pft::LabelSet &labelSet) = 0;
diff --git a/flang/lib/Lower/Bridge.cpp b/flang/lib/Lower/Bridge.cpp
index ff31625c7734f16..761cedb97fb959e 100644
--- a/flang/lib/Lower/Bridge.cpp
+++ b/flang/lib/Lower/Bridge.cpp
@@ -513,6 +513,15 @@ class FirConverter : public Fortran::lower::AbstractConverter {
addSymbol(sym, exval, /*forced=*/true);
}
+ void
+ overrideExprValues(const Fortran::lower::ExprToValueMap *map) override final {
+ exprValueOverrides = map;
+ }
+
+ const Fortran::lower::ExprToValueMap *getExprOverrides() override final {
+ return exprValueOverrides;
+ }
+
bool lookupLabelSet(Fortran::lower::SymbolRef sym,
Fortran::lower::pft::LabelSet &labelSet) override final {
Fortran::lower::pft::FunctionLikeUnit &owningProc =
@@ -4903,6 +4912,8 @@ class FirConverter : public Fortran::lower::AbstractConverter {
/// Whether an OpenMP target region or declare target function/subroutine
/// intended for device offloading has been detected
bool ompDeviceCodeFound = false;
+
+ const Fortran::lower::ExprToValueMap *exprValueOverrides{nullptr};
};
} // namespace
diff --git a/flang/lib/Lower/ConvertExpr.cpp b/flang/lib/Lower/ConvertExpr.cpp
index 6d2ac62b61b74c3..1a2b3856c526716 100644
--- a/flang/lib/Lower/ConvertExpr.cpp
+++ b/flang/lib/Lower/ConvertExpr.cpp
@@ -2963,8 +2963,21 @@ class ScalarExprLowering {
return asArray(x);
}
+ template <typename A>
+ mlir::Value getIfOverridenExpr(const Fortran::evaluate::Expr<A> &x) {
+ if (const Fortran::lower::ExprToValueMap *map =
+ converter.getExprOverrides()) {
+ Fortran::lower::SomeExpr someExpr = toEvExpr(x);
+ if (auto match = map->find(&someExpr); match != map->end())
+ return match->second;
+ }
+ return mlir::Value{};
+ }
+
template <typename A>
ExtValue gen(const Fortran::evaluate::Expr<A> &x) {
+ if (mlir::Value val = getIfOverridenExpr(x))
+ return val;
// Whole array symbols or components, and results of transformational
// functions already have a storage and the scalar expression lowering path
// is used to not create a new temporary storage.
@@ -2978,6 +2991,8 @@ class ScalarExprLowering {
}
template <typename A>
ExtValue genval(const Fortran::evaluate::Expr<A> &x) {
+ if (mlir::Value val = getIfOverridenExpr(x))
+ return val;
if (isScalar(x) || Fortran::evaluate::UnwrapWholeSymbolDataRef(x) ||
inInitializer)
return std::visit([&](const auto &e) { return genval(e); }, x.u);
@@ -2987,6 +3002,8 @@ class ScalarExprLowering {
template <int KIND>
ExtValue genval(const Fortran::evaluate::Expr<Fortran::evaluate::Type<
Fortran::common::TypeCategory::Logical, KIND>> &exp) {
+ if (mlir::Value val = getIfOverridenExpr(exp))
+ return val;
return std::visit([&](const auto &e) { return genval(e); }, exp.u);
}
diff --git a/flang/lib/Lower/ConvertExprToHLFIR.cpp b/flang/lib/Lower/ConvertExprToHLFIR.cpp
index 4cf29c9aecbf577..1da6a5bdd54784e 100644
--- a/flang/lib/Lower/ConvertExprToHLFIR.cpp
+++ b/flang/lib/Lower/ConvertExprToHLFIR.cpp
@@ -1423,6 +1423,17 @@ class HlfirBuilder {
template <typename T>
hlfir::EntityWithAttributes gen(const Fortran::evaluate::Expr<T> &expr) {
+ if (const Fortran::lower::ExprToValueMap *map =
+ getConverter().getExprOverrides()) {
+ if constexpr (std::is_same_v<T, Fortran::evaluate::SomeType>) {
+ if (auto match = map->find(&expr); match != map->end())
+ return hlfir::EntityWithAttributes{match->second};
+ } else {
+ Fortran::lower::SomeExpr someExpr = toEvExpr(expr);
+ if (auto match = map->find(&someExpr); match != map->end())
+ return hlfir::EntityWithAttributes{match->second};
+ }
+ }
return std::visit([&](const auto &x) { return gen(x); }, expr.u);
}
diff --git a/flang/lib/Lower/DirectivesCommon.h b/flang/lib/Lower/DirectivesCommon.h
index eef92160ae1fd45..558fa8931f630ee 100644
--- a/flang/lib/Lower/DirectivesCommon.h
+++ b/flang/lib/Lower/DirectivesCommon.h
@@ -200,62 +200,13 @@ static inline void genOmpAccAtomicUpdateStatement(
mlir::Type varType, const Fortran::parser::Variable &assignmentStmtVariable,
const Fortran::parser::Expr &assignmentStmtExpr,
[[maybe_unused]] const AtomicListT *leftHandClauseList,
- [[maybe_unused]] const AtomicListT *rightHandClauseList) {
+ [[maybe_unused]] const AtomicListT *rightHandClauseList,
+ mlir::Operation *atomicCaptureOp = nullptr) {
// Generate `omp.atomic.update` operation for atomic assignment statements
fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
mlir::Location currentLocation = converter.getCurrentLocation();
- const auto *varDesignator =
- std::get_if<Fortran::common::Indirection<Fortran::parser::Designator>>(
- &assignmentStmtVariable.u);
- assert(varDesignator && "Variable designator for atomic update assignment "
- "statement does not exist");
- const Fortran::parser::Name *name =
- Fortran::semantics::getDesignatorNameIfDataRef(varDesignator->value());
- if (!name)
- TODO(converter.getCurrentLocation(),
- "Array references as atomic update variable");
- assert(name && name->symbol &&
- "No symbol attached to atomic update variable");
- if (Fortran::semantics::IsAllocatableOrPointer(name->symbol->GetUltimate()))
- converter.bindSymbol(*name->symbol, lhsAddr);
-
- // Lowering is in two steps :
- // subroutine sb
- // integer :: a, b
- // !$omp atomic update
- // a = a + b
- // end subroutine
- //
- // 1. Lower to scf.execute_region_op
- //
- // func.func @_QPsb() {
- // %0 = fir.alloca i32 {bindc_name = "a", uniq_name = "_QFsbEa"}
- // %1 = fir.alloca i32 {bindc_name = "b", uniq_name = "_QFsbEb"}
- // %2 = scf.execute_region -> i32 {
- // %3 = fir.load %0 : !fir.ref<i32>
- // %4 = fir.load %1 : !fir.ref<i32>
- // %5 = arith.addi %3, %4 : i32
- // scf.yield %5 : i32
- // }
- // return
- // }
- auto tempOp =
- firOpBuilder.create<mlir::scf::ExecuteRegionOp>(currentLocation, varType);
- firOpBuilder.createBlock(&tempOp.getRegion());
- mlir::Block &block = tempOp.getRegion().back();
- firOpBuilder.setInsertionPointToEnd(&block);
- Fortran::lower::StatementContext stmtCtx;
- mlir::Value rhsExpr = fir::getBase(converter.genExprValue(
- *Fortran::semantics::GetExpr(assignmentStmtExpr), stmtCtx));
- mlir::Value convertResult =
- firOpBuilder.createConvert(currentLocation, varType, rhsExpr);
- // Insert the terminator: YieldOp.
- firOpBuilder.create<mlir::scf::YieldOp>(currentLocation, convertResult);
- firOpBuilder.setInsertionPointToStart(&block);
-
- // 2. Create the omp.atomic.update Operation using the Operations in the
- // temporary scf.execute_region Operation.
+ // Create the omp.atomic.update Operation
//
// func.func @_QPsb() {
// %0 = fir.alloca i32 {bindc_name = "a", uniq_name = "_QFsbEa"}
@@ -269,11 +220,37 @@ static inline void genOmpAccAtomicUpdateStatement(
// }
// return
// }
- mlir::Value updateVar = converter.getSymbolAddress(*name->symbol);
- if (auto decl = updateVar.getDefiningOp<hlfir::DeclareOp>())
- updateVar = decl.getBase();
- firOpBuilder.setInsertionPointAfter(tempOp);
+ Fortran::lower::ExprToValueMap exprValueOverrides;
+ // Lower any non atomic sub-expression before the atomic operation, and
+ // map its lowered value to the semantic representation.
+ const Fortran::lower::SomeExpr *nonAtomicSubExpr{nullptr};
+ std::visit(
+ [&](const auto &op) -> void {
+ using T = std::decay_t<decltype(op)>;
+ if constexpr (std::is_base_of<Fortran::parser::Expr::IntrinsicBinary,
+ T>::value) {
+ const auto &exprLeft{std::get<0>(op.t)};
+ const auto &exprRight{std::get<1>(op.t)};
+ if (exprLeft.value().source == assignmentStmtVariable.GetSource())
+ nonAtomicSubExpr = Fortran::semantics::GetExpr(exprRight);
+ else
+ nonAtomicSubExpr = Fortran::semantics::GetExpr(exprLeft);
+ }
+ },
+ assignmentStmtExpr.u);
+ StatementContext nonAtomicStmtCtx;
+ if (nonAtomicSubExpr) {
+ // Generate non atomic part before all the atomic operations.
+ auto insertionPoint = firOpBuilder.saveInsertionPoint();
+ if (atomicCaptureOp)
+ firOpBuilder.setInsertionPoint(atomicCaptureOp);
+ mlir::Value nonAtomicVal = fir::getBase(converter.genExprValue(
+ currentLocation, *nonAtomicSubExpr, nonAtomicStmtCtx));
+ exprValueOverrides.try_emplace(nonAtomicSubExpr, nonAtomicVal);
+ if (atomicCaptureOp)
+ firOpBuilder.restoreInsertionPoint(insertionPoint);
+ }
mlir::Operation *atomicUpdateOp = nullptr;
if constexpr (std::is_same<AtomicListT,
@@ -289,10 +266,10 @@ static inline void genOmpAccAtomicUpdateStatement(
genOmpAtomicHintAndMemoryOrderClauses(converter, *rightHandClauseList,
hint, memoryOrder);
atomicUpdateOp = firOpBuilder.create<mlir::omp::AtomicUpdateOp>(
- currentLocation, updateVar, hint, memoryOrder);
+ currentLocation, lhsAddr, hint, memoryOrder);
} else {
atomicUpdateOp = firOpBuilder.create<mlir::acc::AtomicUpdateOp>(
- currentLocation, updateVar);
+ currentLocation, lhsAddr);
}
llvm::SmallVector<mlir::Type> varTys = {varType};
@@ -301,38 +278,25 @@ static inline void genOmpAccAtomicUpdateStatement(
mlir::Value val =
fir::getBase(atomicUpdateOp->getRegion(0).front().getArgument(0));
- llvm::SmallVector<mlir::Operation *> ops;
- for (mlir::Operation &op : tempOp.getRegion().getOps())
- ops.push_back(&op);
-
- // SCF Yield is converted to OMP Yield. All other operations are copied
- for (mlir::Operation *op : ops) {
- if (auto y = mlir::dyn_cast<mlir::scf::YieldOp>(op)) {
- firOpBuilder.setInsertionPointToEnd(
- &atomicUpdateOp->getRegion(0).front());
- if constexpr (std::is_same<AtomicListT,
- Fortran::parser::OmpAtomicClauseList>()) {
- firOpBuilder.create<mlir::omp::YieldOp>(currentLocation,
- y.getResults());
- } else {
- firOpBuilder.create<mlir::acc::YieldOp>(currentLocation,
- y.getResults());
- }
- op->erase();
+ exprValueOverrides.try_emplace(
+ Fortran::semantics::GetExpr(assignmentStmtVariable), val);
+ {
+ // statement context inside the atomic block.
+ converter.overrideExprValues(&exprValueOverrides);
+ Fortran::lower::StatementContext atomicStmtCtx;
+ mlir::Value rhsExpr = fir::getBase(converter.genExprValue(
+ *Fortran::semantics::GetExpr(assignmentStmtExpr), atomicStmtCtx));
+ mlir::Value convertResult =
+ firOpBuilder.createConvert(currentLocation, varType, rhsExpr);
+ if constexpr (std::is_same<AtomicListT,
+ Fortran::parser::OmpAtomicClauseList>()) {
+ firOpBuilder.create<mlir::omp::YieldOp>(currentLocation, convertResult);
} else {
- op->remove();
- atomicUpdateOp->getRegion(0).front().push_back(op);
+ firOpBuilder.create<mlir::acc::YieldOp>(currentLocation, convertResult);
}
+ converter.resetExprOverrides();
}
-
- // Remove the load and replace all uses of load with the block argument
- for (mlir::Operation &op : atomicUpdateOp->getRegion(0).getOps()) {
- fir::LoadOp y = mlir::dyn_cast<fir::LoadOp>(&op);
- if (y && y.getMemref() == updateVar)
- y.getRes().replaceAllUsesWith(val);
- }
-
- tempOp.erase();
+ firOpBuilder.setInsertionPointAfter(atomicUpdateOp);
}
/// Processes an atomic construct with write clause.
@@ -423,11 +387,7 @@ void genOmpAccAtomicUpdate(Fortran::lower::AbstractConverter &converter,
Fortran::lower::StatementContext stmtCtx;
mlir::Value lhsAddr = fir::getBase(converter.genExprAddr(
*Fortran::semantics::GetExpr(assignmentStmtVariable), stmtCtx));
- mlir::Type varType =
- fir::getBase(
- converter.genExprValue(
- *Fortran::semantics::GetExpr(assignmentStmtVariable), stmtCtx))
- .getType();
+ mlir::Type varType = fir::unwrapRefType(lhsAddr.getType());
genOmpAccAtomicUpdateStatement<AtomicListT>(
converter, lhsAddr, varType, assignmentStmtVariable, assignmentStmtExpr,
leftHandClauseList, rightHandClauseList);
@@ -450,11 +410,7 @@ void genOmpAtomic(Fortran::lower::AbstractConverter &converter,
Fortran::lower::StatementContext stmtCtx;
mlir::Value lhsAddr = fir::getBase(converter.genExprAddr(
*Fortran::semantics::GetExpr(assignmentStmtVariable), stmtCtx));
- mlir::Type varType =
- fir::getBase(
- converter.genExprValue(
- *Fortran::semantics::GetExpr(assignmentStmtVariable), stmtCtx))
- .getType();
+ mlir::Type varType = fir::unwrapRefType(lhsAddr.getType());
// If atomic-clause is not present on the construct, the behaviour is as if
// the update clause is specified (for both OpenMP and OpenACC).
genOmpAccAtomicUpdateStatement<AtomicListT>(
@@ -551,7 +507,7 @@ void genOmpAccAtomicCapture(Fortran::lower::AbstractConverter &converter,
genOmpAccAtomicUpdateStatement<AtomicListT>(
converter, stmt1RHSArg, stmt2VarType, stmt2Var, stmt2Expr,
/*leftHandClauseList=*/nullptr,
- /*rightHandClauseList=*/nullptr);
+ /*rightHandClauseList=*/nullptr, atomicCaptureOp);
} else {
// Atomic capture construct is of the form [capture-stmt, write-stmt]
const Fortran::semantics::SomeExpr &fromExpr =
@@ -580,7 +536,7 @@ void genOmpAccAtomicCapture(Fortran::lower::AbstractConverter &converter,
genOmpAccAtomicUpdateStatement<AtomicListT>(
converter, stmt1LHSArg, stmt1VarType, stmt1Var, stmt1Expr,
/*leftHandClauseList=*/nullptr,
- /*rightHandClauseList=*/nullptr);
+ /*rightHandClauseList=*/nullptr, atomicCaptureOp);
}
firOpBuilder.setInsertionPointToEnd(&block);
if constexpr (std::is_same<AtomicListT,
diff --git a/flang/test/Lower/OpenACC/acc-atomic-capture.f90 b/flang/test/Lower/OpenACC/acc-atomic-capture.f90
index 1a5fe8d57c1533a..382991cf7221ba7 100644
--- a/flang/test/Lower/OpenACC/acc-atomic-capture.f90
+++ b/flang/test/Lower/OpenACC/acc-atomic-capture.f90
@@ -7,11 +7,11 @@ program acc_atomic_capture_test
!CHECK: %[[X:.*]] = fir.alloca i32 {bindc_name = "x", uniq_name = "_QFEx"}
!CHECK: %[[Y:.*]] = fir.alloca i32 {bindc_name = "y", uniq_name = "_QFEy"}
+!CHECK: %[[temp:.*]] = fir.load %[[X]] : !fir.ref<i32>
!CHECK: acc.atomic.capture {
!CHECK: acc.atomic.read %[[X]] = %[[Y]] : !fir.ref<i32>
!CHECK: acc.atomic.update %[[Y]] : !fir.ref<i32> {
!CHECK: ^bb0(%[[ARG:.*]]: i32):
-!CHECK: %[[temp:.*]] = fir.load %[[X]] : !fir.ref<i32>
!CHECK: %[[result:.*]] = arith.addi %[[temp]], %[[ARG]] : i32
!CHECK: acc.yield %[[result]] : i32
!CHECK: }
@@ -23,10 +23,10 @@ program acc_atomic_capture_test
!$acc end atomic
+!CHECK: %[[temp:.*]] = fir.load %[[X]] : !fir.ref<i32>
!CHECK: acc.atomic.capture {
!CHECK: acc.atomic.update %[[Y]] : !fir.ref<i32> {
!CHECK: ^bb0(%[[ARG:.*]]: i32):
-!CHECK: %[[temp:.*]] = fir.load %[[X]] : !fir.ref<i32>
!CHECK: %[[result:.*]] = arith.muli %[[temp]], %[[ARG]] : i32
!CHECK: acc.yield %[[result]] : i32
!CHECK: }
@@ -76,12 +76,12 @@ subroutine pointers_in_atomic_capture()
!CHECK: %[[loaded_A_addr:.*]] = fir.box_addr %[[loaded_A]] : (!fir.box<!fir.ptr<i32>>) -> !fir.ptr<i32>
!CHECK: %[[loaded_B:.*]] = fir.load %[[B]] : !fir.ref<!fir.box<!fir.ptr<i32>>>
!CHECK: %[[loaded_B_addr:.*]] = fir.box_addr %[[loaded_B]] : (!fir.box<!fir.ptr<i32>>) -> !fir.ptr<i32>
-!CHECK: acc.atomic.capture {
-!CHECK: acc.atomic.update %[[loaded_A_addr]] : !fir.ptr<i32> {
-!CHECK: ^bb0(%[[ARG:.*]]: i32):
!CHECK: %[[PRIVATE_LOADED_B:.*]] = fir.load %[[B]] : !fir.ref<!fir.box<!fir.ptr<i32>>>
!CHECK: %[[PRIVATE_LOADED_B_addr:.*]] = fir.box_addr %[[PRIVATE_LOADED_B]] : (!fir.box<!fir.ptr<i32>>) -> !fir.ptr<i32>
!CHECK: %[[loaded_value:.*]] = fir.load %[[PRIVATE_LOADED_B_addr]] : !fir.ptr<i32>
+!CHECK: acc.atomic.capture {
+!CHECK: acc.atomic.update %[[loaded_A_addr]] : !fir.ptr<i32> {
+!CHECK: ^bb0(%[[ARG:.*]]: i32):
!CHECK: %[[result:.*]] = arith.addi %[[ARG]], %[[loaded_value]] : i32
!CHECK: acc.yield %[[result]] : i32
!CHECK: }
diff --git a/flang/test/Lower/OpenACC/acc-atomic-update-hlfir.f90 b/flang/test/Lower/OpenACC/acc-atomic-update-hlfir.f90
index 24dd0ee5a8999e4..b2a993ddd825105 100644
--- a/flang/test/Lower/OpenACC/acc-atomic-update-hlfir.f90
+++ b/flang/test/Lower/OpenACC/acc-atomic-update-hlfir.f90
@@ -14,9 +14,9 @@ subroutine sb
!CHECK: %[[X_DECL:.*]]:2 = hlfir.declare %[[X_REF]] {uniq_name = "_QFsbEx"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
!CHECK: %[[Y_REF:.*]] = fir.alloca i32 {bindc_name = "y", uniq_name = "_QFsbEy"}
!CHECK: %[[Y_DECL:.*]]:2 = hlfir.declare %[[Y_REF]] {uniq_name = "_QFsbEy"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
-!CHECK: acc.atomic.update %[[X_DECL]]#0 : !fir.ref<i32> {
+!CHECK: %[[Y_VAL:.*]] = fir.load %[[Y_DECL]]#0 : !fir.ref<i32>
+!CHECK: acc.atomic.update %[[X_DECL]]#1 : !fir.ref<i32> {
!CHECK: ^bb0(%[[ARG_X:.*]]: i32):
-!CHECK: %[[Y_VAL:.*]] = fir.load %[[Y_DECL]]#0 : !fir.ref<i32>
!CHECK: %[[X_UPDATE_VAL:.*]] = arith.addi %[[ARG_X]], %[[Y_VAL]] : i32
!CHECK: acc.yield %[[X_UPDATE_VAL]] : i32
!CHECK: }
diff --git a/flang/test/Lower/OpenACC/acc-atomic-update.f90 b/flang/test/Lower/OpenACC/acc-atomic-update.f90
index 546d012982e23ba..96e1d5e58e6f482 100644
--- a/flang/test/Lower/OpenACC/acc-atomic-update.f90
+++ b/flang/test/Lower/OpenACC/acc-atomic-update.f90
@@ -31,25 +31,25 @@ program acc_atomic_update_test
!CHECK: %{{.*}} = fir.convert %[[D_ADDR]] : (!fir.ref<i32>) -> !fir.ptr<i32>
!CHECK: fir.store {{.*}} to %[[B_ADDR]] : !fir.ref<!fir.ptr<i32>>
!CHECK: %[[LOADED_A:.*]] = fir.load %[[A_ADDR]] : !fir.ref<!fir.ptr<i32>>
+!CHECK: %[[LOADED_B:.*]] = fir.load %[[B_ADDR]] : !fir.ref<!fir.ptr<i32>>
+!CHECK: %{{.*}} = fir.load %[[LOADED_B]] : !fir.ptr<i32>
!CHECK: acc.atomic.update %[[LOADED_A]] : !fir.ptr<i32> {
!CHECK: ^bb0(%[[ARG:.*]]: i32):
-!CHECK: %[[LOADED_B:.*]] = fir.load %[[B_ADDR]] : !fir.ref<!fir.ptr<i32>>
-!CHECK: %{{.*}} = fir.load %[[LOADED_B]] : !fir.ptr<i32>
!CHECK: %[[RESULT:.*]] = arith.addi %[[ARG]], %{{.*}} : i32
!CHECK: acc.yield %[[RESULT]] : i32
!CHECK: }
!$acc atomic update
a = a + b
+!CHECK: {{.*}} = arith.constant 1 : i32
!CHECK: acc.atomic.update %[[Y]] : !fir.ref<i32> {
!CHECK: ^bb0(%[[ARG:.*]]: i32):
-!CHECK: {{.*}} = arith.constant 1 : i32
!CHECK: %[[RESULT:.*]] = arith.addi %[[ARG]], {{.*}} : i32
!CHECK: acc.yield %[[RESULT]] : i32
!CHECK: }
+!CHECK: %[[LOADED_X:.*]] = fir.load %[[X]] : !fir.ref<i32>
!CHECK: acc.atomic.update %[[Z]] : !fir.ref<i32> {
!CHECK: ^bb0(%[[ARG:.*]]: i32):
-!CHECK: %[[LOADED_X:.*]] = fir.load %[[X]] : !fir.ref<i32>
!CHECK: %[[RESULT:.*]] = arith.muli %[[LOADED_X]], %[[ARG]] : i32
!CHECK: acc.yield %[[RESULT]] : i32
!CHECK: }
@@ -58,10 +58,10 @@ program acc_atomic_update_test
!$acc atomic update
z = x * z
+!CHECK: %[[C1_VAL:.*]] = arith.constant 1 : i32
!CHECK: acc.atomic.update %[[I1]] : !fir.ref<i8> {
!CHECK: ^bb0(%[[VAL:.*]]: i8):
!CHECK: %[[CVT_VAL:.*]] = fir.convert %[[VAL]] : (i8) -> i32
-!CHECK: %[[C1_VAL:.*]] = arith.constant 1 : i32
!CHECK: %[[ADD_VAL:.*]] = arith.addi %[[CVT_VAL]], %[[C1_VAL]] : i32
!CHECK: %[[UPDATED_VAL:.*]] = fir.convert %[[ADD_VAL]] : (i32) -> i8
!CHECK: acc.yield %[[UPDATED_VAL]] : i8
diff --git a/flang/test/Lower/OpenMP/FIR/atomic-capture.f90 b/flang/test/Lower/OpenMP/FIR/atomic-capture.f90
index f48a5efaf4354e7..a8b04a2e90cd46b 100644
--- a/flang/test/Lower/OpenMP/FIR/atomic-capture.f90
+++ b/flang/test/Lower/OpenMP/FIR/atomic-capture.f90
@@ -8,11 +8,11 @@ program OmpAtomicCapture
!CHECK: %[[X:.*]] = fir.alloca i32 {bindc_name = "x", uniq_name = "_QFEx"}
!CHECK: %[[Y:.*]] = fir.alloca i32 {bindc_name = "y", uniq_name = "_QFEy"}
+!CHECK: %[[temp:.*]] = fir.load %[[X]] : !fir.ref<i32>
!CHECK: omp.atomic.capture memory_order(release) {
!CHECK: omp.atomic.read %[[X]] = %[[Y]] : !fir.ref<i32>
!CHECK: omp.atomic.update %[[Y]] : !fir.ref<i32> {
!CHECK: ^bb0(%[[ARG:.*]]: i32):
-!CHECK: %[[temp:.*]] = fir.load %[[X]] : !fir.ref<i32>
!CHECK: %[[result:.*]] = arith.addi %[[temp]], %[[ARG]] : i32
!CHECK: omp.yield(%[[result]] : i32)
!CHECK: }
@@ -24,10 +24,10 @@ program OmpAtomicCapture
!$omp end atomic
+!CHECK: %[[temp:.*]] = fir.load %[[X]] : !fir.ref<i32>
!CHECK: omp.atomic.capture hint(uncontended) {
!CHECK: omp.atomic.update %[[Y]] : !fir.ref<i32> {
!CHECK: ^bb0(%[[ARG:.*]]: i32):
-!CHECK: %[[temp:.*]] = fir.load %[[X]] : !fir.ref<i32>
!CHECK: %[[result:.*]] = arith.muli %[[temp]], %[[ARG]] : i32
!CHECK: omp.yield(%[[result]] : i32)
!CHECK: }
@@ -94,12 +94,12 @@ subroutine pointers_in_atomic_capture()
!CHECK: %[[loaded_A_addr:.*]] = fir.box_addr %[[loaded_A]] : (!fir.box<!fir.ptr<i32>>) -> !fir.ptr<i32>
!CHECK: %[[loaded_B:.*]] = fir.load %[[B]] : !fir.ref<!fir.box<!fir.ptr<i32>>>
!CHECK: %[[loaded_B_addr:.*]] = fir.box_addr %[[loaded_B]] : (!fir.box<!fir.ptr<i32>>) -> !fir.ptr<i32>
-!CHECK: omp.atomic.capture {
-!CHECK: omp.atomic.update %[[loaded_A_addr]] : !fir.ptr<i32> {
-!CHECK: ^bb0(%[[ARG:.*]]: i32):
!CHECK: %[[PRIVATE_LOADED_B:.*]] = fir.load %[[B]] : !fir.ref<!fir.box<!fir.ptr<i32>>>
!CHECK: %[[PRIVATE_LOADED_B_addr:.*]] = fir.box_addr %[[PRIVATE_LOADED_B]] : (!fir.box<!fir.ptr<i32>>) -> !fir.ptr<i32>
!CHECK: %[[loaded_value:.*]] = fir.load %[[PRIVATE_LOADED_B_addr]] : !fir.ptr<i32>
+!CHECK: omp.atomic.capture {
+!CHECK: omp.atomic.update %[[loaded_A_addr]] : !fir.ptr<i32> {
+!CHECK: ^bb0(%[[ARG:.*]]: i32):
!CHECK: %[[result:.*]] = arith.addi %[[ARG]], %[[loaded_value]] : i32
!CHECK: omp.yield(%[[result]] : i32)
!CHECK: }
diff --git a/flang/test/Lower/OpenMP/FIR/atomic-update.f90 b/flang/test/Lower/OpenMP/FIR/atomic-update.f90
index d0185d2f3b14dfe..56ced10901ab677 100644
--- a/flang/test/Lower/OpenMP/FIR/atomic-update.f90
+++ b/flang/test/Lower/OpenMP/FIR/atomic-update.f90
@@ -32,25 +32,25 @@ program OmpAtomicUpdate
!CHECK: %{{.*}} = fir.convert %[[D_ADDR]] : (!fir.ref<i32>) -> !fir.ptr<i32>
!CHECK: fir.store {{.*}} to %[[B_ADDR]] : !fir.ref<!fir.ptr<i32>>
!CHECK: %[[LOADED_A:.*]] = fir.load %[[A_ADDR]] : !fir.ref<!fir.ptr<i32>>
+!CHECK: %[[LOADED_B:.*]] = fir.load %[[B_ADDR]] : !fir.ref<!fir.ptr<i32>>
+!CHECK: %{{.*}} = fir.load %[[LOADED_B]] : !fir.ptr<i32>
!CHECK: omp.atomic.update %[[LOADED_A]] : !fir.ptr<i32> {
!CHECK: ^bb0(%[[ARG:.*]]: i32):
-!CHECK: %[[LOADED_B:.*]] = fir.load %[[B_ADDR]] : !fir.ref<!fir.ptr<i32>>
-!CHECK: %{{.*}} = fir.load %[[LOADED_B]] : !fir.ptr<i32>
!CHECK: %[[RESULT:.*]] = arith.addi %[[ARG]], %{{.*}} : i32
!CHECK: omp.yield(%[[RESULT]] : i32)
!CHECK: }
!$omp atomic update
a = a + b
+!CHECK: {{.*}} = arith.constant 1 : i32
!CHECK: omp.atomic.update %[[Y]] : !fir.ref<i32> {
!CHECK: ^bb0(%[[ARG:.*]]: i32):
-!CHECK: {{.*}} = arith.constant 1 : i32
!CHECK: %[[RESULT:.*]] = arith.addi %[[ARG]], {{.*}} : i32
!CHECK: omp.yield(%[[RESULT]] : i32)
!CHECK: }
+!CHECK: %[[LOADED_X:.*]] = fir.load %[[X]] : !fir.ref<i32>
!CHECK: omp.atomic.update %[[Z]] : !fir.ref<i32> {
!CHECK: ^bb0(%[[ARG:.*]]: i32):
-!CHECK: %[[LOADED_X:.*]] = fir.load %[[X]] : !fir.ref<i32>
!CHECK: %[[RESULT:.*]] = arith.muli %[[LOADED_X]], %[[ARG]] : i32
!CHECK: omp.yield(%[[RESULT]] : i32)
!CHECK: }
@@ -59,9 +59,9 @@ program OmpAtomicUpdate
!$omp atomic update
z = x * z
+!CHECK: %{{.*}} = arith.constant 1 : i32
!CHECK: omp.atomic.update memory_order(relaxed) hint(uncontended) %[[X]] : !fir.ref<i32> {
!CHECK: ^bb0(%[[ARG:.*]]: i32):
-!CHECK: %{{.*}} = arith.constant 1 : i32
!CHECK: %[[RESULT:.*]] = arith.subi %[[ARG]], {{.*}} : i32
!CHECK: omp.yield(%[[RESULT]] : i32)
!CHECK: }
@@ -75,9 +75,9 @@ program OmpAtomicUpdate
!CHECK: %[[RESULT:.*]] = arith.select %{{.*}}, %{{.*}}, %[[LOADED_Z]] : i32
!CHECK: omp.yield(%[[RESULT]] : i32)
!CHECK: }
+!CHECK: %[[LOADED_X:.*]] = fir.load %[[X]] : !fir.ref<i32>
!CHECK: omp.atomic.update memory_order(relaxed) hint(contended) %[[Z]] : !fir.ref<i32> {
!CHECK: ^bb0(%[[ARG:.*]]: i32):
-!CHECK: %[[LOADED_X:.*]] = fir.load %[[X]] : !fir.ref<i32>
!CHECK: %[[RESULT:.*]] = arith.addi %[[ARG]], %[[LOADED_X]] : i32
!CHECK: omp.yield(%[[RESULT]] : i32)
!CHECK: }
@@ -88,15 +88,15 @@ program OmpAtomicUpdate
!$omp atomic relaxed hint(omp_sync_hint_contended)
z = z + x
+!CHECK: %{{.*}} = arith.constant 10 : i32
!CHECK: omp.atomic.update memory_order(release) hint(contended) %[[Z]] : !fir.ref<i32> {
!CHECK: ^bb0(%[[ARG:.*]]: i32):
-!CHECK: %{{.*}} = arith.constant 10 : i32
!CHECK: %[[RESULT:.*]] = arith.muli {{.*}}, %[[ARG]] : i32
!CHECK: omp.yield(%[[RESULT]] : i32)
!CHECK: }
+!CHECK: %[[LOADED_Z:.*]] = fir.load %[[Z]] : !fir.ref<i32>
!CHECK: omp.atomic.update memory_order(release) hint(speculative) %[[X]] : !fir.ref<i32> {
!CHECK: ^bb0(%[[ARG:.*]]: i32):
-!CHECK: %[[LOADED_Z:.*]] = fir.load %[[Z]] : !fir.ref<i32>
!CHECK: %[[RESULT:.*]] = arith.divsi %[[ARG]], %[[LOADED_Z]] : i32
!CHECK: omp.yield(%[[RESULT]] : i32)
!CHECK: }
@@ -106,15 +106,15 @@ program OmpAtomicUpdate
!$omp atomic hint(omp_lock_hint_speculative) update release
x = x / z
+!CHECK: %{{.*}} = arith.constant 10 : i32
!CHECK: omp.atomic.update memory_order(seq_cst) hint(nonspeculative) %[[Y]] : !fir.ref<i32> {
!CHECK: ^bb0(%[[ARG:.*]]: i32):
-!CHECK: %{{.*}} = arith.constant 10 : i32
!CHECK: %[[RESULT:.*]] = arith.addi %{{.*}}, %[[ARG]] : i32
!CHECK: omp.yield(%[[RESULT]] : i32)
!CHECK: }
+!CHECK: %[[LOADED_Y:.*]] = fir.load %[[Y]] : !fir.ref<i32>
!CHECK: omp.atomic.update memory_order(seq_cst) %[[Z]] : !fir.ref<i32> {
!CHECK: ^bb0(%[[ARG:.*]]: i32):
-!CHECK: %[[LOADED_Y:.*]] = fir.load %[[Y]] : !fir.ref<i32>
!CHECK: %[[RESULT:.*]] = arith.addi %[[LOADED_Y]], %[[ARG]] : i32
!CHECK: omp.yield(%[[RESULT]] : i32)
!CHECK: }
@@ -123,10 +123,10 @@ program OmpAtomicUpdate
!$omp atomic seq_cst update
z = y + z
+!CHECK: %[[C1_VAL:.*]] = arith.constant 1 : i32
!CHECK: omp.atomic.update %[[I1]] : !fir.ref<i8> {
!CHECK: ^bb0(%[[VAL:.*]]: i8):
!CHECK: %[[CVT_VAL:.*]] = fir.convert %[[VAL]] : (i8) -> i32
-!CHECK: %[[C1_VAL:.*]] = arith.constant 1 : i32
!CHECK: %[[ADD_VAL:.*]] = arith.addi %[[CVT_VAL]], %[[C1_VAL]] : i32
!CHECK: %[[UPDATED_VAL:.*]] = fir.convert %[[ADD_VAL]] : (i32) -> i8
!CHECK: omp.yield(%[[UPDATED_VAL]] : i8)
diff --git a/flang/test/Lower/OpenMP/atomic-update-hlfir.f90 b/flang/test/Lower/OpenMP/atomic-update-hlfir.f90
index f00ed495ae6f89f..329009ab8ef8e9b 100644
--- a/flang/test/Lower/OpenMP/atomic-update-hlfir.f90
+++ b/flang/test/Lower/OpenMP/atomic-update-hlfir.f90
@@ -14,9 +14,9 @@ subroutine sb
!CHECK: %[[X_DECL:.*]]:2 = hlfir.declare %[[X_REF]] {uniq_name = "_QFsbEx"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
!CHECK: %[[Y_REF:.*]] = fir.alloca i32 {bindc_name = "y", uniq_name = "_QFsbEy"}
!CHECK: %[[Y_DECL:.*]]:2 = hlfir.declare %[[Y_REF]] {uniq_name = "_QFsbEy"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
-!CHECK: omp.atomic.update %[[X_DECL]]#0 : !fir.ref<i32> {
+!CHECK: %[[Y_VAL:.*]] = fir.load %[[Y_DECL]]#0 : !fir.ref<i32>
+!CHECK: omp.atomic.update %[[X_DECL]]#1 : !fir.ref<i32> {
!CHECK: ^bb0(%[[ARG_X:.*]]: i32):
-!CHECK: %[[Y_VAL:.*]] = fir.load %[[Y_DECL]]#0 : !fir.ref<i32>
!CHECK: %[[X_UPDATE_VAL:.*]] = arith.addi %[[ARG_X]], %[[Y_VAL]] : i32
!CHECK: omp.yield(%[[X_UPDATE_VAL]] : i32)
!CHECK: }
diff --git a/flang/test/Lower/OpenMP/common-atomic-lowering.f90 b/flang/test/Lower/OpenMP/common-atomic-lowering.f90
new file mode 100644
index 000000000000000..076091b44b33a71
--- /dev/null
+++ b/flang/test/Lower/OpenMP/common-atomic-lowering.f90
@@ -0,0 +1,74 @@
+!RUN: %flang_fc1 -emit-hlfir -fopenmp %s -o - | FileCheck %s
+
+!CHECK: func.func @_QQmain() attributes {fir.bindc_name = "sample"} {
+!CHECK: %[[val_0:.*]] = fir.alloca i32 {bindc_name = "a", uniq_name = "_QFEa"}
+!CHECK: %[[val_1:.*]]:2 = hlfir.declare %[[val_0]] {uniq_name = "_QFEa"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
+!CHECK: %[[val_2:.*]] = fir.alloca i32 {bindc_name = "b", uniq_name = "_QFEb"}
+!CHECK: %[[val_3:.*]]:2 = hlfir.declare %[[val_2]] {uniq_name = "_QFEb"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
+!CHECK: %[[val_4:.*]] = fir.alloca i32 {bindc_name = "x", uniq_name = "_QFEx"}
+!CHECK: %[[val_5:.*]]:2 = hlfir.declare %[[val_4]] {uniq_name = "_QFEx"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
+!CHECK: %[[val_c5:.*]] = arith.constant 5 : index
+!CHECK: %[[val_6:.*]] = fir.alloca !fir.array<5xi32> {bindc_name = "y", uniq_name = "_QFEy"}
+!CHECK: %[[val_7:.*]] = fir.shape %[[val_c5]] : (index) -> !fir.shape<1>
+!CHECK: %[[val_8:.*]]:2 = hlfir.declare %[[val_6]](%[[val_7]]) {uniq_name = "_QFEy"} : (!fir.ref<!fir.array<5xi32>>, !fir.shape<1>) -> (!fir.ref<!fir.array<5xi32>>, !fir.ref<!fir.array<5xi32>>)
+!CHECK: %[[val_c2:.*]] = arith.constant 2 : index
+!CHECK: %[[val_9:.*]] = hlfir.designate %[[val_8]]#0 (%[[val_c2]]) : (!fir.ref<!fir.array<5xi32>>, index) -> !fir.ref<i32>
+!CHECK: %[[val_c8:.*]] = arith.constant 8 : i32
+!CHECK: %[[val_10:.*]] = fir.load %[[val_5]]#0 : !fir.ref<i32>
+!CHECK: %[[val_11:.*]] = arith.addi %[[val_c8]], %[[val_10]] : i32
+!CHECK: %[[val_12:.*]] = hlfir.no_reassoc %[[val_11]] : i32
+!CHECK: omp.atomic.update %[[val_9]] : !fir.ref<i32> {
+!CHECK: ^bb0(%[[ARG:.*]]: i32):
+!CHECK: %[[val_18:.*]] = arith.muli %[[val_12]], %[[ARG]] : i32
+!CHECK: omp.yield(%[[val_18]] : i32)
+!CHECK: }
+!CHECK: %[[val_c2_0:.*]] = arith.constant 2 : index
+!CHECK: %[[val_13:.*]] = hlfir.designate %[[val_8]]#0 (%[[val_c2_0]]) : (!fir.ref<!fir.array<5xi32>>, index) -> !fir.ref<i32>
+!CHECK: %[[val_c8_1:.*]] = arith.constant 8 : i32
+!CHECK: omp.atomic.update %[[val_13:.*]] : !fir.ref<i32> {
+!CHECK: ^bb0(%[[ARG:.*]]: i32):
+!CHECK: %[[val_18:.*]] = arith.divsi %[[ARG]], %[[val_c8_1]] : i32
+!CHECK: omp.yield(%[[val_18]] : i32)
+!CHECK: }
+!CHECK: %[[val_c8_2:.*]] = arith.constant 8 : i32
+!CHECK: %[[val_c4:.*]] = arith.constant 4 : index
+!CHECK: %[[val_14:.*]] = hlfir.designate %[[val_8]]#0 (%[[val_c4]]) : (!fir.ref<!fir.array<5xi32>>, index) -> !fir.ref<i32>
+!CHECK: %[[val_15:.*]] = fir.load %[[val_14]] : !fir.ref<i32>
+!CHECK: %[[val_16:.*]] = arith.addi %[[val_c8_2]], %[[val_15]] : i32
+!CHECK: %[[val_17:.*]] = hlfir.no_reassoc %[[val_16]] : i32
+!CHECK: omp.atomic.update %[[val_5]]#1 : !fir.ref<i32> {
+!CHECK: ^bb0(%[[ARG:.*]]: i32):
+!CHECK: %[[val_18:.*]] = arith.addi %[[val_17]], %[[ARG]] : i32
+!CHECK: omp.yield(%[[val_18]] : i32)
+!CHECK: }
+!CHECK: %[[val_c8_3:.*]] = arith.constant 8 : i32
+!CHECK: omp.atomic.update %[[val_5]]#1 : !fir.ref<i32> {
+!CHECK: ^bb0(%[[ARG]]: i32):
+!CHECK: %[[val_18:.*]] = arith.subi %[[val_c8_3]], %[[ARG]] : i32
+!CHECK: omp.yield(%[[val_18]] : i32)
+!CHECK: }
+!CHECK: return
+!CHECK: }
+program sample
+
+ integer :: x
+ integer, dimension(5) :: y
+ integer :: a, b
+
+ !$omp atomic update
+ y(2) = (8 + x) * y(2)
+ !$omp end atomic
+
+ !$omp atomic update
+ y(2) = y(2) / 8
+ !$omp end atomic
+
+ !$omp atomic update
+ x = (8 + y(4)) + x
+ !$omp end atomic
+
+ !$omp atomic update
+ x = 8 - x
+ !$omp end atomic
+
+end program sample
More information about the flang-commits mailing list