[flang-commits] [flang] b6b0756 - [flang] Allow lowering of sub-expressions to be overridden (#69944)

via flang-commits flang-commits at lists.llvm.org
Wed Oct 25 00:22:28 PDT 2023


Author: jeanPerier
Date: 2023-10-25T09:22:23+02:00
New Revision: b6b0756ce5c4e2e07d7f6f1f430d3d29afe9a8a8

URL: https://github.com/llvm/llvm-project/commit/b6b0756ce5c4e2e07d7f6f1f430d3d29afe9a8a8
DIFF: https://github.com/llvm/llvm-project/commit/b6b0756ce5c4e2e07d7f6f1f430d3d29afe9a8a8.diff

LOG: [flang] Allow lowering of sub-expressions to be overridden (#69944)

OpenACC/OpenMP atomic lowering needs finer control over expression
lowering. This patch allows mapping evaluate::Expr<T> to mlir::Value so
that any subsequent expression lowering uses these values whenever an
operand is a mapped Expr<T>.

This is an alternative to
https://github.com/llvm/llvm-project/pull/69866, from which I took the
test and some of the logic to extract the non-atomic sub-expression.

---------

Co-authored-by: Nimish Mishra <neelam.nimish at gmail.com>

Added: 
    flang/test/Lower/OpenMP/common-atomic-lowering.f90

Modified: 
    flang/include/flang/Lower/AbstractConverter.h
    flang/lib/Lower/Bridge.cpp
    flang/lib/Lower/ConvertExpr.cpp
    flang/lib/Lower/ConvertExprToHLFIR.cpp
    flang/lib/Lower/DirectivesCommon.h
    flang/test/Lower/OpenACC/acc-atomic-capture.f90
    flang/test/Lower/OpenACC/acc-atomic-update-hlfir.f90
    flang/test/Lower/OpenACC/acc-atomic-update.f90
    flang/test/Lower/OpenMP/FIR/atomic-capture.f90
    flang/test/Lower/OpenMP/FIR/atomic-update.f90
    flang/test/Lower/OpenMP/atomic-update-hlfir.f90

Removed: 
    


################################################################################
diff  --git a/flang/include/flang/Lower/AbstractConverter.h b/flang/include/flang/Lower/AbstractConverter.h
index c792e75f1146499..fa67729fe036684 100644
--- a/flang/include/flang/Lower/AbstractConverter.h
+++ b/flang/include/flang/Lower/AbstractConverter.h
@@ -60,6 +60,8 @@ using SomeExpr = Fortran::evaluate::Expr<Fortran::evaluate::SomeType>;
 using SymbolRef = Fortran::common::Reference<const Fortran::semantics::Symbol>;
 class StatementContext;
 
+using ExprToValueMap = llvm::DenseMap<const SomeExpr *, mlir::Value>;
+
 //===----------------------------------------------------------------------===//
 // AbstractConverter interface
 //===----------------------------------------------------------------------===//
@@ -90,6 +92,14 @@ class AbstractConverter {
   /// added or replaced at the inner-most level of the local symbol map.
   virtual void bindSymbol(SymbolRef sym, const fir::ExtendedValue &exval) = 0;
 
+  /// Override lowering of expressions with pre-lowered values: all subsequent
+  /// calls to genExprXXX() replace any occurrence of an overridden expression
+  /// in the expression tree by its mapped mlir::Value. The map is held by
+  /// pointer (not copied) and must outlive its use; pass nullptr to clear.
+  virtual void overrideExprValues(const ExprToValueMap *) = 0;
+  void resetExprOverrides() { overrideExprValues(nullptr); }
+  virtual const ExprToValueMap *getExprOverrides() = 0;
+
   /// Get the label set associated with a symbol.
   virtual bool lookupLabelSet(SymbolRef sym, pft::LabelSet &labelSet) = 0;
 

diff  --git a/flang/lib/Lower/Bridge.cpp b/flang/lib/Lower/Bridge.cpp
index ff31625c7734f16..761cedb97fb959e 100644
--- a/flang/lib/Lower/Bridge.cpp
+++ b/flang/lib/Lower/Bridge.cpp
@@ -513,6 +513,15 @@ class FirConverter : public Fortran::lower::AbstractConverter {
     addSymbol(sym, exval, /*forced=*/true);
   }
 
+  void
+  overrideExprValues(const Fortran::lower::ExprToValueMap *map) override final {
+    exprValueOverrides = map;
+  }
+
+  const Fortran::lower::ExprToValueMap *getExprOverrides() override final {
+    return exprValueOverrides;
+  }
+
   bool lookupLabelSet(Fortran::lower::SymbolRef sym,
                       Fortran::lower::pft::LabelSet &labelSet) override final {
     Fortran::lower::pft::FunctionLikeUnit &owningProc =
@@ -4903,6 +4912,8 @@ class FirConverter : public Fortran::lower::AbstractConverter {
   /// Whether an OpenMP target region or declare target function/subroutine
   /// intended for device offloading has been detected
   bool ompDeviceCodeFound = false;
+  /// Pre-lowered expression values set by overrideExprValues() (not owned).
+  const Fortran::lower::ExprToValueMap *exprValueOverrides{nullptr};
 };
 
 } // namespace

diff  --git a/flang/lib/Lower/ConvertExpr.cpp b/flang/lib/Lower/ConvertExpr.cpp
index 6d2ac62b61b74c3..1a2b3856c526716 100644
--- a/flang/lib/Lower/ConvertExpr.cpp
+++ b/flang/lib/Lower/ConvertExpr.cpp
@@ -2963,8 +2963,21 @@ class ScalarExprLowering {
     return asArray(x);
   }
 
+  /// Return the pre-lowered override for \p x, or a null mlir::Value.
+  template <typename A>
+  mlir::Value getIfOverridenExpr(const Fortran::evaluate::Expr<A> &x) {
+    if (const Fortran::lower::ExprToValueMap *map =
+            converter.getExprOverrides()) {
+      Fortran::lower::SomeExpr someExpr = toEvExpr(x);
+      if (auto match = map->find(&someExpr); match != map->end())
+        return match->second;
+    }
+    return mlir::Value{};
+  }
   template <typename A>
   ExtValue gen(const Fortran::evaluate::Expr<A> &x) {
+    if (mlir::Value val = getIfOverridenExpr(x))
+      return val;
     // Whole array symbols or components, and results of transformational
     // functions already have a storage and the scalar expression lowering path
     // is used to not create a new temporary storage.
@@ -2978,6 +2991,8 @@ class ScalarExprLowering {
   }
   template <typename A>
   ExtValue genval(const Fortran::evaluate::Expr<A> &x) {
+    if (mlir::Value val = getIfOverridenExpr(x))
+      return val;
     if (isScalar(x) || Fortran::evaluate::UnwrapWholeSymbolDataRef(x) ||
         inInitializer)
       return std::visit([&](const auto &e) { return genval(e); }, x.u);
@@ -2987,6 +3002,8 @@ class ScalarExprLowering {
   template <int KIND>
   ExtValue genval(const Fortran::evaluate::Expr<Fortran::evaluate::Type<
                       Fortran::common::TypeCategory::Logical, KIND>> &exp) {
+    if (mlir::Value val = getIfOverridenExpr(exp))
+      return val;
     return std::visit([&](const auto &e) { return genval(e); }, exp.u);
   }
 

diff  --git a/flang/lib/Lower/ConvertExprToHLFIR.cpp b/flang/lib/Lower/ConvertExprToHLFIR.cpp
index 4cf29c9aecbf577..1da6a5bdd54784e 100644
--- a/flang/lib/Lower/ConvertExprToHLFIR.cpp
+++ b/flang/lib/Lower/ConvertExprToHLFIR.cpp
@@ -1423,6 +1423,17 @@ class HlfirBuilder {
 
   template <typename T>
   hlfir::EntityWithAttributes gen(const Fortran::evaluate::Expr<T> &expr) {
+    if (const Fortran::lower::ExprToValueMap *map =
+            getConverter().getExprOverrides()) {
+      if constexpr (std::is_same_v<T, Fortran::evaluate::SomeType>) {
+        if (auto match = map->find(&expr); match != map->end())
+          return hlfir::EntityWithAttributes{match->second};
+      } else {
+        Fortran::lower::SomeExpr someExpr = toEvExpr(expr);
+        if (auto match = map->find(&someExpr); match != map->end())
+          return hlfir::EntityWithAttributes{match->second};
+      }
+    }
     return std::visit([&](const auto &x) { return gen(x); }, expr.u);
   }
 

diff  --git a/flang/lib/Lower/DirectivesCommon.h b/flang/lib/Lower/DirectivesCommon.h
index eef92160ae1fd45..558fa8931f630ee 100644
--- a/flang/lib/Lower/DirectivesCommon.h
+++ b/flang/lib/Lower/DirectivesCommon.h
@@ -200,62 +200,13 @@ static inline void genOmpAccAtomicUpdateStatement(
     mlir::Type varType, const Fortran::parser::Variable &assignmentStmtVariable,
     const Fortran::parser::Expr &assignmentStmtExpr,
     [[maybe_unused]] const AtomicListT *leftHandClauseList,
-    [[maybe_unused]] const AtomicListT *rightHandClauseList) {
+    [[maybe_unused]] const AtomicListT *rightHandClauseList,
+    mlir::Operation *atomicCaptureOp = nullptr) {
   // Generate `omp.atomic.update` operation for atomic assignment statements
   fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
   mlir::Location currentLocation = converter.getCurrentLocation();
 
-  const auto *varDesignator =
-      std::get_if<Fortran::common::Indirection<Fortran::parser::Designator>>(
-          &assignmentStmtVariable.u);
-  assert(varDesignator && "Variable designator for atomic update assignment "
-                          "statement does not exist");
-  const Fortran::parser::Name *name =
-      Fortran::semantics::getDesignatorNameIfDataRef(varDesignator->value());
-  if (!name)
-    TODO(converter.getCurrentLocation(),
-         "Array references as atomic update variable");
-  assert(name && name->symbol &&
-         "No symbol attached to atomic update variable");
-  if (Fortran::semantics::IsAllocatableOrPointer(name->symbol->GetUltimate()))
-    converter.bindSymbol(*name->symbol, lhsAddr);
-
-  //  Lowering is in two steps :
-  //  subroutine sb
-  //    integer :: a, b
-  //    !$omp atomic update
-  //      a = a + b
-  //  end subroutine
-  //
-  //  1. Lower to scf.execute_region_op
-  //
-  //  func.func @_QPsb() {
-  //    %0 = fir.alloca i32 {bindc_name = "a", uniq_name = "_QFsbEa"}
-  //    %1 = fir.alloca i32 {bindc_name = "b", uniq_name = "_QFsbEb"}
-  //    %2 = scf.execute_region -> i32 {
-  //      %3 = fir.load %0 : !fir.ref<i32>
-  //      %4 = fir.load %1 : !fir.ref<i32>
-  //      %5 = arith.addi %3, %4 : i32
-  //      scf.yield %5 : i32
-  //    }
-  //    return
-  //  }
-  auto tempOp =
-      firOpBuilder.create<mlir::scf::ExecuteRegionOp>(currentLocation, varType);
-  firOpBuilder.createBlock(&tempOp.getRegion());
-  mlir::Block &block = tempOp.getRegion().back();
-  firOpBuilder.setInsertionPointToEnd(&block);
-  Fortran::lower::StatementContext stmtCtx;
-  mlir::Value rhsExpr = fir::getBase(converter.genExprValue(
-      *Fortran::semantics::GetExpr(assignmentStmtExpr), stmtCtx));
-  mlir::Value convertResult =
-      firOpBuilder.createConvert(currentLocation, varType, rhsExpr);
-  // Insert the terminator: YieldOp.
-  firOpBuilder.create<mlir::scf::YieldOp>(currentLocation, convertResult);
-  firOpBuilder.setInsertionPointToStart(&block);
-
-  //  2. Create the omp.atomic.update Operation using the Operations in the
-  //     temporary scf.execute_region Operation.
+  //  Create the omp.atomic.update Operation
   //
   //  func.func @_QPsb() {
   //    %0 = fir.alloca i32 {bindc_name = "a", uniq_name = "_QFsbEa"}
@@ -269,11 +220,37 @@ static inline void genOmpAccAtomicUpdateStatement(
   //    }
   //    return
   //  }
-  mlir::Value updateVar = converter.getSymbolAddress(*name->symbol);
-  if (auto decl = updateVar.getDefiningOp<hlfir::DeclareOp>())
-    updateVar = decl.getBase();
 
-  firOpBuilder.setInsertionPointAfter(tempOp);
+  Fortran::lower::ExprToValueMap exprValueOverrides;
+  // For a binary RHS, lower the non-atomic operand (the right one if the left
+  // is spelled like the variable, else the left one) before the atomic op.
+  const Fortran::lower::SomeExpr *nonAtomicSubExpr{nullptr};
+  std::visit(
+      [&](const auto &op) -> void {
+        using T = std::decay_t<decltype(op)>;
+        if constexpr (std::is_base_of<Fortran::parser::Expr::IntrinsicBinary,
+                                      T>::value) {
+          const auto &exprLeft{std::get<0>(op.t)};
+          const auto &exprRight{std::get<1>(op.t)};
+          if (exprLeft.value().source == assignmentStmtVariable.GetSource())
+            nonAtomicSubExpr = Fortran::semantics::GetExpr(exprRight);
+          else
+            nonAtomicSubExpr = Fortran::semantics::GetExpr(exprLeft);
+        }
+      },
+      assignmentStmtExpr.u);
+  StatementContext nonAtomicStmtCtx;
+  if (nonAtomicSubExpr) {
+    // Emit it ahead of the atomic update op, and of atomicCaptureOp if any.
+    auto insertionPoint = firOpBuilder.saveInsertionPoint();
+    if (atomicCaptureOp)
+      firOpBuilder.setInsertionPoint(atomicCaptureOp);
+    mlir::Value nonAtomicVal = fir::getBase(converter.genExprValue(
+        currentLocation, *nonAtomicSubExpr, nonAtomicStmtCtx));
+    exprValueOverrides.try_emplace(nonAtomicSubExpr, nonAtomicVal);
+    if (atomicCaptureOp)
+      firOpBuilder.restoreInsertionPoint(insertionPoint);
+  }
 
   mlir::Operation *atomicUpdateOp = nullptr;
   if constexpr (std::is_same<AtomicListT,
@@ -289,10 +266,10 @@ static inline void genOmpAccAtomicUpdateStatement(
       genOmpAtomicHintAndMemoryOrderClauses(converter, *rightHandClauseList,
                                             hint, memoryOrder);
     atomicUpdateOp = firOpBuilder.create<mlir::omp::AtomicUpdateOp>(
-        currentLocation, updateVar, hint, memoryOrder);
+        currentLocation, lhsAddr, hint, memoryOrder);
   } else {
     atomicUpdateOp = firOpBuilder.create<mlir::acc::AtomicUpdateOp>(
-        currentLocation, updateVar);
+        currentLocation, lhsAddr);
   }
 
   llvm::SmallVector<mlir::Type> varTys = {varType};
@@ -301,38 +278,25 @@ static inline void genOmpAccAtomicUpdateStatement(
   mlir::Value val =
       fir::getBase(atomicUpdateOp->getRegion(0).front().getArgument(0));
 
-  llvm::SmallVector<mlir::Operation *> ops;
-  for (mlir::Operation &op : tempOp.getRegion().getOps())
-    ops.push_back(&op);
-
-  // SCF Yield is converted to OMP Yield. All other operations are copied
-  for (mlir::Operation *op : ops) {
-    if (auto y = mlir::dyn_cast<mlir::scf::YieldOp>(op)) {
-      firOpBuilder.setInsertionPointToEnd(
-          &atomicUpdateOp->getRegion(0).front());
-      if constexpr (std::is_same<AtomicListT,
-                                 Fortran::parser::OmpAtomicClauseList>()) {
-        firOpBuilder.create<mlir::omp::YieldOp>(currentLocation,
-                                                y.getResults());
-      } else {
-        firOpBuilder.create<mlir::acc::YieldOp>(currentLocation,
-                                                y.getResults());
-      }
-      op->erase();
+  exprValueOverrides.try_emplace(
+      Fortran::semantics::GetExpr(assignmentStmtVariable), val);
+  {
+    // Lower the RHS inside the atomic region with the overrides installed.
+    converter.overrideExprValues(&exprValueOverrides);
+    Fortran::lower::StatementContext atomicStmtCtx;
+    mlir::Value rhsExpr = fir::getBase(converter.genExprValue(
+        *Fortran::semantics::GetExpr(assignmentStmtExpr), atomicStmtCtx));
+    mlir::Value convertResult =
+        firOpBuilder.createConvert(currentLocation, varType, rhsExpr);
+    if constexpr (std::is_same<AtomicListT,
+                               Fortran::parser::OmpAtomicClauseList>()) {
+      firOpBuilder.create<mlir::omp::YieldOp>(currentLocation, convertResult);
     } else {
-      op->remove();
-      atomicUpdateOp->getRegion(0).front().push_back(op);
+      firOpBuilder.create<mlir::acc::YieldOp>(currentLocation, convertResult);
     }
+    converter.resetExprOverrides();
   }
-
-  // Remove the load and replace all uses of load with the block argument
-  for (mlir::Operation &op : atomicUpdateOp->getRegion(0).getOps()) {
-    fir::LoadOp y = mlir::dyn_cast<fir::LoadOp>(&op);
-    if (y && y.getMemref() == updateVar)
-      y.getRes().replaceAllUsesWith(val);
-  }
-
-  tempOp.erase();
+  firOpBuilder.setInsertionPointAfter(atomicUpdateOp);
 }
 
 /// Processes an atomic construct with write clause.
@@ -423,11 +387,7 @@ void genOmpAccAtomicUpdate(Fortran::lower::AbstractConverter &converter,
   Fortran::lower::StatementContext stmtCtx;
   mlir::Value lhsAddr = fir::getBase(converter.genExprAddr(
       *Fortran::semantics::GetExpr(assignmentStmtVariable), stmtCtx));
-  mlir::Type varType =
-      fir::getBase(
-          converter.genExprValue(
-              *Fortran::semantics::GetExpr(assignmentStmtVariable), stmtCtx))
-          .getType();
+  mlir::Type varType = fir::unwrapRefType(lhsAddr.getType());
   genOmpAccAtomicUpdateStatement<AtomicListT>(
       converter, lhsAddr, varType, assignmentStmtVariable, assignmentStmtExpr,
       leftHandClauseList, rightHandClauseList);
@@ -450,11 +410,7 @@ void genOmpAtomic(Fortran::lower::AbstractConverter &converter,
   Fortran::lower::StatementContext stmtCtx;
   mlir::Value lhsAddr = fir::getBase(converter.genExprAddr(
       *Fortran::semantics::GetExpr(assignmentStmtVariable), stmtCtx));
-  mlir::Type varType =
-      fir::getBase(
-          converter.genExprValue(
-              *Fortran::semantics::GetExpr(assignmentStmtVariable), stmtCtx))
-          .getType();
+  mlir::Type varType = fir::unwrapRefType(lhsAddr.getType());
   // If atomic-clause is not present on the construct, the behaviour is as if
   // the update clause is specified (for both OpenMP and OpenACC).
   genOmpAccAtomicUpdateStatement<AtomicListT>(
@@ -551,7 +507,7 @@ void genOmpAccAtomicCapture(Fortran::lower::AbstractConverter &converter,
       genOmpAccAtomicUpdateStatement<AtomicListT>(
           converter, stmt1RHSArg, stmt2VarType, stmt2Var, stmt2Expr,
           /*leftHandClauseList=*/nullptr,
-          /*rightHandClauseList=*/nullptr);
+          /*rightHandClauseList=*/nullptr, atomicCaptureOp);
     } else {
       // Atomic capture construct is of the form [capture-stmt, write-stmt]
       const Fortran::semantics::SomeExpr &fromExpr =
@@ -580,7 +536,7 @@ void genOmpAccAtomicCapture(Fortran::lower::AbstractConverter &converter,
     genOmpAccAtomicUpdateStatement<AtomicListT>(
         converter, stmt1LHSArg, stmt1VarType, stmt1Var, stmt1Expr,
         /*leftHandClauseList=*/nullptr,
-        /*rightHandClauseList=*/nullptr);
+        /*rightHandClauseList=*/nullptr, atomicCaptureOp);
   }
   firOpBuilder.setInsertionPointToEnd(&block);
   if constexpr (std::is_same<AtomicListT,

diff  --git a/flang/test/Lower/OpenACC/acc-atomic-capture.f90 b/flang/test/Lower/OpenACC/acc-atomic-capture.f90
index 1a5fe8d57c1533a..382991cf7221ba7 100644
--- a/flang/test/Lower/OpenACC/acc-atomic-capture.f90
+++ b/flang/test/Lower/OpenACC/acc-atomic-capture.f90
@@ -7,11 +7,11 @@ program acc_atomic_capture_test
 
 !CHECK: %[[X:.*]] = fir.alloca i32 {bindc_name = "x", uniq_name = "_QFEx"}
 !CHECK: %[[Y:.*]] = fir.alloca i32 {bindc_name = "y", uniq_name = "_QFEy"}
+!CHECK: %[[temp:.*]] = fir.load %[[X]] : !fir.ref<i32>
 !CHECK: acc.atomic.capture {
 !CHECK: acc.atomic.read %[[X]] = %[[Y]] : !fir.ref<i32>
 !CHECK: acc.atomic.update %[[Y]] : !fir.ref<i32> {
 !CHECK: ^bb0(%[[ARG:.*]]: i32):
-!CHECK: %[[temp:.*]] = fir.load %[[X]] : !fir.ref<i32>
 !CHECK: %[[result:.*]] = arith.addi %[[temp]], %[[ARG]] : i32
 !CHECK: acc.yield %[[result]] : i32
 !CHECK: }
@@ -23,10 +23,10 @@ program acc_atomic_capture_test
     !$acc end atomic
 
 
+!CHECK: %[[temp:.*]] = fir.load %[[X]] : !fir.ref<i32>
 !CHECK: acc.atomic.capture {
 !CHECK: acc.atomic.update %[[Y]] : !fir.ref<i32> {
 !CHECK: ^bb0(%[[ARG:.*]]: i32):
-!CHECK: %[[temp:.*]] = fir.load %[[X]] : !fir.ref<i32>
 !CHECK: %[[result:.*]] = arith.muli %[[temp]], %[[ARG]] : i32
 !CHECK: acc.yield %[[result]] : i32
 !CHECK: }
@@ -76,12 +76,12 @@ subroutine pointers_in_atomic_capture()
 !CHECK: %[[loaded_A_addr:.*]] = fir.box_addr %[[loaded_A]] : (!fir.box<!fir.ptr<i32>>) -> !fir.ptr<i32>
 !CHECK: %[[loaded_B:.*]] = fir.load %[[B]] : !fir.ref<!fir.box<!fir.ptr<i32>>>
 !CHECK: %[[loaded_B_addr:.*]] = fir.box_addr %[[loaded_B]] : (!fir.box<!fir.ptr<i32>>) -> !fir.ptr<i32>
-!CHECK: acc.atomic.capture   {
-!CHECK: acc.atomic.update %[[loaded_A_addr]] : !fir.ptr<i32> {
-!CHECK: ^bb0(%[[ARG:.*]]: i32):
 !CHECK: %[[PRIVATE_LOADED_B:.*]] = fir.load %[[B]] : !fir.ref<!fir.box<!fir.ptr<i32>>>
 !CHECK: %[[PRIVATE_LOADED_B_addr:.*]] = fir.box_addr %[[PRIVATE_LOADED_B]] : (!fir.box<!fir.ptr<i32>>) -> !fir.ptr<i32>
 !CHECK: %[[loaded_value:.*]] = fir.load %[[PRIVATE_LOADED_B_addr]] : !fir.ptr<i32>
+!CHECK: acc.atomic.capture   {
+!CHECK: acc.atomic.update %[[loaded_A_addr]] : !fir.ptr<i32> {
+!CHECK: ^bb0(%[[ARG:.*]]: i32):
 !CHECK: %[[result:.*]] = arith.addi %[[ARG]], %[[loaded_value]] : i32
 !CHECK: acc.yield %[[result]] : i32
 !CHECK: }

diff  --git a/flang/test/Lower/OpenACC/acc-atomic-update-hlfir.f90 b/flang/test/Lower/OpenACC/acc-atomic-update-hlfir.f90
index 24dd0ee5a8999e4..b2a993ddd825105 100644
--- a/flang/test/Lower/OpenACC/acc-atomic-update-hlfir.f90
+++ b/flang/test/Lower/OpenACC/acc-atomic-update-hlfir.f90
@@ -14,9 +14,9 @@ subroutine sb
 !CHECK:   %[[X_DECL:.*]]:2 = hlfir.declare %[[X_REF]] {uniq_name = "_QFsbEx"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
 !CHECK:   %[[Y_REF:.*]] = fir.alloca i32 {bindc_name = "y", uniq_name = "_QFsbEy"}
 !CHECK:   %[[Y_DECL:.*]]:2 = hlfir.declare %[[Y_REF]] {uniq_name = "_QFsbEy"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
-!CHECK:   acc.atomic.update   %[[X_DECL]]#0 : !fir.ref<i32> {
+!CHECK:   %[[Y_VAL:.*]] = fir.load %[[Y_DECL]]#0 : !fir.ref<i32>
+!CHECK:   acc.atomic.update   %[[X_DECL]]#1 : !fir.ref<i32> {
 !CHECK:   ^bb0(%[[ARG_X:.*]]: i32):
-!CHECK:     %[[Y_VAL:.*]] = fir.load %[[Y_DECL]]#0 : !fir.ref<i32>
 !CHECK:     %[[X_UPDATE_VAL:.*]] = arith.addi %[[ARG_X]], %[[Y_VAL]] : i32
 !CHECK:     acc.yield %[[X_UPDATE_VAL]] : i32
 !CHECK:   }

diff  --git a/flang/test/Lower/OpenACC/acc-atomic-update.f90 b/flang/test/Lower/OpenACC/acc-atomic-update.f90
index 546d012982e23ba..96e1d5e58e6f482 100644
--- a/flang/test/Lower/OpenACC/acc-atomic-update.f90
+++ b/flang/test/Lower/OpenACC/acc-atomic-update.f90
@@ -31,25 +31,25 @@ program acc_atomic_update_test
 !CHECK: %{{.*}} = fir.convert %[[D_ADDR]] : (!fir.ref<i32>) -> !fir.ptr<i32>
 !CHECK: fir.store {{.*}} to %[[B_ADDR]] : !fir.ref<!fir.ptr<i32>>
 !CHECK: %[[LOADED_A:.*]] = fir.load %[[A_ADDR]] : !fir.ref<!fir.ptr<i32>>
+!CHECK: %[[LOADED_B:.*]] = fir.load %[[B_ADDR]] : !fir.ref<!fir.ptr<i32>>
+!CHECK: %{{.*}} = fir.load %[[LOADED_B]] : !fir.ptr<i32>
 !CHECK:  acc.atomic.update   %[[LOADED_A]] : !fir.ptr<i32> {
 !CHECK:  ^bb0(%[[ARG:.*]]: i32):
-!CHECK:    %[[LOADED_B:.*]] = fir.load %[[B_ADDR]] : !fir.ref<!fir.ptr<i32>>
-!CHECK:    %{{.*}} = fir.load %[[LOADED_B]] : !fir.ptr<i32>
 !CHECK:    %[[RESULT:.*]] = arith.addi %[[ARG]], %{{.*}} : i32
 !CHECK:    acc.yield %[[RESULT]] : i32
 !CHECK: }
     !$acc atomic update
         a = a + b 
 
+!CHECK: {{.*}} = arith.constant 1 : i32
 !CHECK: acc.atomic.update   %[[Y]] : !fir.ref<i32> {
 !CHECK:  ^bb0(%[[ARG:.*]]: i32):
-!CHECK:    {{.*}} = arith.constant 1 : i32
 !CHECK:    %[[RESULT:.*]] = arith.addi %[[ARG]], {{.*}} : i32
 !CHECK:    acc.yield %[[RESULT]] : i32
 !CHECK:  }
+!CHECK:  %[[LOADED_X:.*]] = fir.load %[[X]] : !fir.ref<i32>
 !CHECK:  acc.atomic.update   %[[Z]] : !fir.ref<i32> {
 !CHECK:  ^bb0(%[[ARG:.*]]: i32):
-!CHECK:    %[[LOADED_X:.*]] = fir.load %[[X]] : !fir.ref<i32>
 !CHECK:    %[[RESULT:.*]] = arith.muli %[[LOADED_X]], %[[ARG]] : i32
 !CHECK:    acc.yield %[[RESULT]] : i32
 !CHECK:  }
@@ -58,10 +58,10 @@ program acc_atomic_update_test
     !$acc atomic update
         z = x * z 
 
+!CHECK:  %[[C1_VAL:.*]] = arith.constant 1 : i32
 !CHECK:  acc.atomic.update   %[[I1]] : !fir.ref<i8> {
 !CHECK:  ^bb0(%[[VAL:.*]]: i8):
 !CHECK:    %[[CVT_VAL:.*]] = fir.convert %[[VAL]] : (i8) -> i32
-!CHECK:    %[[C1_VAL:.*]] = arith.constant 1 : i32
 !CHECK:    %[[ADD_VAL:.*]] = arith.addi %[[CVT_VAL]], %[[C1_VAL]] : i32
 !CHECK:    %[[UPDATED_VAL:.*]] = fir.convert %[[ADD_VAL]] : (i32) -> i8
 !CHECK:    acc.yield %[[UPDATED_VAL]] : i8

diff  --git a/flang/test/Lower/OpenMP/FIR/atomic-capture.f90 b/flang/test/Lower/OpenMP/FIR/atomic-capture.f90
index f48a5efaf4354e7..a8b04a2e90cd46b 100644
--- a/flang/test/Lower/OpenMP/FIR/atomic-capture.f90
+++ b/flang/test/Lower/OpenMP/FIR/atomic-capture.f90
@@ -8,11 +8,11 @@ program OmpAtomicCapture
 
 !CHECK: %[[X:.*]] = fir.alloca i32 {bindc_name = "x", uniq_name = "_QFEx"}
 !CHECK: %[[Y:.*]] = fir.alloca i32 {bindc_name = "y", uniq_name = "_QFEy"}
+!CHECK: %[[temp:.*]] = fir.load %[[X]] : !fir.ref<i32>
 !CHECK: omp.atomic.capture   memory_order(release) {
 !CHECK: omp.atomic.read %[[X]] = %[[Y]] : !fir.ref<i32>
 !CHECK: omp.atomic.update %[[Y]] : !fir.ref<i32> {
 !CHECK: ^bb0(%[[ARG:.*]]: i32):
-!CHECK: %[[temp:.*]] = fir.load %[[X]] : !fir.ref<i32>
 !CHECK: %[[result:.*]] = arith.addi %[[temp]], %[[ARG]] : i32
 !CHECK: omp.yield(%[[result]] : i32)
 !CHECK: }
@@ -24,10 +24,10 @@ program OmpAtomicCapture
     !$omp end atomic
 
 
+!CHECK: %[[temp:.*]] = fir.load %[[X]] : !fir.ref<i32>
 !CHECK: omp.atomic.capture   hint(uncontended) {
 !CHECK: omp.atomic.update %[[Y]] : !fir.ref<i32> {
 !CHECK: ^bb0(%[[ARG:.*]]: i32):
-!CHECK: %[[temp:.*]] = fir.load %[[X]] : !fir.ref<i32>
 !CHECK: %[[result:.*]] = arith.muli %[[temp]], %[[ARG]] : i32
 !CHECK: omp.yield(%[[result]] : i32)
 !CHECK: }
@@ -94,12 +94,12 @@ subroutine pointers_in_atomic_capture()
 !CHECK: %[[loaded_A_addr:.*]] = fir.box_addr %[[loaded_A]] : (!fir.box<!fir.ptr<i32>>) -> !fir.ptr<i32>
 !CHECK: %[[loaded_B:.*]] = fir.load %[[B]] : !fir.ref<!fir.box<!fir.ptr<i32>>>
 !CHECK: %[[loaded_B_addr:.*]] = fir.box_addr %[[loaded_B]] : (!fir.box<!fir.ptr<i32>>) -> !fir.ptr<i32>
-!CHECK: omp.atomic.capture   {
-!CHECK: omp.atomic.update %[[loaded_A_addr]] : !fir.ptr<i32> {
-!CHECK: ^bb0(%[[ARG:.*]]: i32):
 !CHECK: %[[PRIVATE_LOADED_B:.*]] = fir.load %[[B]] : !fir.ref<!fir.box<!fir.ptr<i32>>>
 !CHECK: %[[PRIVATE_LOADED_B_addr:.*]] = fir.box_addr %[[PRIVATE_LOADED_B]] : (!fir.box<!fir.ptr<i32>>) -> !fir.ptr<i32>
 !CHECK: %[[loaded_value:.*]] = fir.load %[[PRIVATE_LOADED_B_addr]] : !fir.ptr<i32>
+!CHECK: omp.atomic.capture   {
+!CHECK: omp.atomic.update %[[loaded_A_addr]] : !fir.ptr<i32> {
+!CHECK: ^bb0(%[[ARG:.*]]: i32):
 !CHECK: %[[result:.*]] = arith.addi %[[ARG]], %[[loaded_value]] : i32
 !CHECK: omp.yield(%[[result]] : i32)
 !CHECK: }

diff  --git a/flang/test/Lower/OpenMP/FIR/atomic-update.f90 b/flang/test/Lower/OpenMP/FIR/atomic-update.f90
index d0185d2f3b14dfe..56ced10901ab677 100644
--- a/flang/test/Lower/OpenMP/FIR/atomic-update.f90
+++ b/flang/test/Lower/OpenMP/FIR/atomic-update.f90
@@ -32,25 +32,25 @@ program OmpAtomicUpdate
 !CHECK: %{{.*}} = fir.convert %[[D_ADDR]] : (!fir.ref<i32>) -> !fir.ptr<i32>
 !CHECK: fir.store {{.*}} to %[[B_ADDR]] : !fir.ref<!fir.ptr<i32>>
 !CHECK: %[[LOADED_A:.*]] = fir.load %[[A_ADDR]] : !fir.ref<!fir.ptr<i32>>
+!CHECK: %[[LOADED_B:.*]] = fir.load %[[B_ADDR]] : !fir.ref<!fir.ptr<i32>>
+!CHECK: %{{.*}} = fir.load %[[LOADED_B]] : !fir.ptr<i32>
 !CHECK:  omp.atomic.update   %[[LOADED_A]] : !fir.ptr<i32> {
 !CHECK:  ^bb0(%[[ARG:.*]]: i32):
-!CHECK:    %[[LOADED_B:.*]] = fir.load %[[B_ADDR]] : !fir.ref<!fir.ptr<i32>>
-!CHECK:    %{{.*}} = fir.load %[[LOADED_B]] : !fir.ptr<i32>
 !CHECK:    %[[RESULT:.*]] = arith.addi %[[ARG]], %{{.*}} : i32
 !CHECK:    omp.yield(%[[RESULT]] : i32)
 !CHECK: }
     !$omp atomic update
         a = a + b 
 
+!CHECK: {{.*}} = arith.constant 1 : i32
 !CHECK: omp.atomic.update   %[[Y]] : !fir.ref<i32> {
 !CHECK:  ^bb0(%[[ARG:.*]]: i32):
-!CHECK:    {{.*}} = arith.constant 1 : i32
 !CHECK:    %[[RESULT:.*]] = arith.addi %[[ARG]], {{.*}} : i32
 !CHECK:    omp.yield(%[[RESULT]] : i32)
 !CHECK:  }
+!CHECK:  %[[LOADED_X:.*]] = fir.load %[[X]] : !fir.ref<i32>
 !CHECK:  omp.atomic.update   %[[Z]] : !fir.ref<i32> {
 !CHECK:  ^bb0(%[[ARG:.*]]: i32):
-!CHECK:    %[[LOADED_X:.*]] = fir.load %[[X]] : !fir.ref<i32>
 !CHECK:    %[[RESULT:.*]] = arith.muli %[[LOADED_X]], %[[ARG]] : i32
 !CHECK:    omp.yield(%[[RESULT]] : i32)
 !CHECK:  }
@@ -59,9 +59,9 @@ program OmpAtomicUpdate
     !$omp atomic update
         z = x * z 
 
+!CHECK:  %{{.*}} = arith.constant 1 : i32
 !CHECK:  omp.atomic.update   memory_order(relaxed) hint(uncontended) %[[X]] : !fir.ref<i32> {
 !CHECK:  ^bb0(%[[ARG:.*]]: i32):
-!CHECK:    %{{.*}} = arith.constant 1 : i32
 !CHECK:    %[[RESULT:.*]] = arith.subi %[[ARG]], {{.*}} : i32
 !CHECK:    omp.yield(%[[RESULT]] : i32)
 !CHECK:  }
@@ -75,9 +75,9 @@ program OmpAtomicUpdate
 !CHECK:    %[[RESULT:.*]] = arith.select %{{.*}}, %{{.*}}, %[[LOADED_Z]] : i32
 !CHECK:    omp.yield(%[[RESULT]] : i32)
 !CHECK:  }
+!CHECK:  %[[LOADED_X:.*]] = fir.load %[[X]] : !fir.ref<i32>
 !CHECK:  omp.atomic.update   memory_order(relaxed) hint(contended) %[[Z]] : !fir.ref<i32> {
 !CHECK:  ^bb0(%[[ARG:.*]]: i32):
-!CHECK:    %[[LOADED_X:.*]] = fir.load %[[X]] : !fir.ref<i32>
 !CHECK:    %[[RESULT:.*]] = arith.addi %[[ARG]], %[[LOADED_X]] : i32
 !CHECK:    omp.yield(%[[RESULT]] : i32)
 !CHECK:  }
@@ -88,15 +88,15 @@ program OmpAtomicUpdate
     !$omp atomic relaxed hint(omp_sync_hint_contended)
         z = z + x
 
+!CHECK:  %{{.*}} = arith.constant 10 : i32
 !CHECK:  omp.atomic.update   memory_order(release) hint(contended) %[[Z]] : !fir.ref<i32> {
 !CHECK:  ^bb0(%[[ARG:.*]]: i32):
-!CHECK:    %{{.*}} = arith.constant 10 : i32
 !CHECK:   %[[RESULT:.*]] = arith.muli {{.*}}, %[[ARG]] : i32
 !CHECK:    omp.yield(%[[RESULT]] : i32)
 !CHECK:  }
+!CHECK:  %[[LOADED_Z:.*]] = fir.load %[[Z]] : !fir.ref<i32>
 !CHECK:  omp.atomic.update   memory_order(release) hint(speculative) %[[X]] : !fir.ref<i32> {
 !CHECK:  ^bb0(%[[ARG:.*]]: i32):
-!CHECK:    %[[LOADED_Z:.*]] = fir.load %[[Z]] : !fir.ref<i32>
 !CHECK:    %[[RESULT:.*]] = arith.divsi %[[ARG]], %[[LOADED_Z]] : i32
 !CHECK:    omp.yield(%[[RESULT]] : i32)
 !CHECK:  }
@@ -106,15 +106,15 @@ program OmpAtomicUpdate
     !$omp atomic hint(omp_lock_hint_speculative) update release
         x = x / z
 
+!CHECK:  %{{.*}} = arith.constant 10 : i32
 !CHECK:  omp.atomic.update   memory_order(seq_cst) hint(nonspeculative) %[[Y]] : !fir.ref<i32> {
 !CHECK:  ^bb0(%[[ARG:.*]]: i32):
-!CHECK:    %{{.*}} = arith.constant 10 : i32
 !CHECK:    %[[RESULT:.*]] = arith.addi %{{.*}}, %[[ARG]] : i32
 !CHECK:   omp.yield(%[[RESULT]] : i32)
 !CHECK:  }
+!CHECK:  %[[LOADED_Y:.*]] = fir.load %[[Y]] : !fir.ref<i32>
 !CHECK:  omp.atomic.update   memory_order(seq_cst) %[[Z]] : !fir.ref<i32> {
 !CHECK:  ^bb0(%[[ARG:.*]]: i32):
-!CHECK:    %[[LOADED_Y:.*]] = fir.load %[[Y]] : !fir.ref<i32>
 !CHECK:    %[[RESULT:.*]] = arith.addi %[[LOADED_Y]], %[[ARG]] : i32
 !CHECK:    omp.yield(%[[RESULT]] : i32)
 !CHECK:  }
@@ -123,10 +123,10 @@ program OmpAtomicUpdate
     !$omp atomic seq_cst update
         z = y + z
 
+!CHECK:  %[[C1_VAL:.*]] = arith.constant 1 : i32
 !CHECK:  omp.atomic.update   %[[I1]] : !fir.ref<i8> {
 !CHECK:  ^bb0(%[[VAL:.*]]: i8):
 !CHECK:    %[[CVT_VAL:.*]] = fir.convert %[[VAL]] : (i8) -> i32
-!CHECK:    %[[C1_VAL:.*]] = arith.constant 1 : i32
 !CHECK:    %[[ADD_VAL:.*]] = arith.addi %[[CVT_VAL]], %[[C1_VAL]] : i32
 !CHECK:    %[[UPDATED_VAL:.*]] = fir.convert %[[ADD_VAL]] : (i32) -> i8
 !CHECK:    omp.yield(%[[UPDATED_VAL]] : i8)

diff  --git a/flang/test/Lower/OpenMP/atomic-update-hlfir.f90 b/flang/test/Lower/OpenMP/atomic-update-hlfir.f90
index f00ed495ae6f89f..329009ab8ef8e9b 100644
--- a/flang/test/Lower/OpenMP/atomic-update-hlfir.f90
+++ b/flang/test/Lower/OpenMP/atomic-update-hlfir.f90
@@ -14,9 +14,9 @@ subroutine sb
 !CHECK:   %[[X_DECL:.*]]:2 = hlfir.declare %[[X_REF]] {uniq_name = "_QFsbEx"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
 !CHECK:   %[[Y_REF:.*]] = fir.alloca i32 {bindc_name = "y", uniq_name = "_QFsbEy"}
 !CHECK:   %[[Y_DECL:.*]]:2 = hlfir.declare %[[Y_REF]] {uniq_name = "_QFsbEy"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
-!CHECK:   omp.atomic.update   %[[X_DECL]]#0 : !fir.ref<i32> {
+!CHECK:   %[[Y_VAL:.*]] = fir.load %[[Y_DECL]]#0 : !fir.ref<i32>
+!CHECK:   omp.atomic.update   %[[X_DECL]]#1 : !fir.ref<i32> {
 !CHECK:   ^bb0(%[[ARG_X:.*]]: i32):
-!CHECK:     %[[Y_VAL:.*]] = fir.load %[[Y_DECL]]#0 : !fir.ref<i32>
 !CHECK:     %[[X_UPDATE_VAL:.*]] = arith.addi %[[ARG_X]], %[[Y_VAL]] : i32
 !CHECK:     omp.yield(%[[X_UPDATE_VAL]] : i32)
 !CHECK:   }

diff  --git a/flang/test/Lower/OpenMP/common-atomic-lowering.f90 b/flang/test/Lower/OpenMP/common-atomic-lowering.f90
new file mode 100644
index 000000000000000..076091b44b33a71
--- /dev/null
+++ b/flang/test/Lower/OpenMP/common-atomic-lowering.f90
@@ -0,0 +1,74 @@
+!RUN: %flang_fc1 -emit-hlfir -fopenmp %s -o - | FileCheck %s
+
+!CHECK: func.func @_QQmain() attributes {fir.bindc_name = "sample"} {
+!CHECK: %[[val_0:.*]] = fir.alloca i32 {bindc_name = "a", uniq_name = "_QFEa"}
+!CHECK: %[[val_1:.*]]:2 = hlfir.declare %[[val_0]] {uniq_name = "_QFEa"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
+!CHECK: %[[val_2:.*]] = fir.alloca i32 {bindc_name = "b", uniq_name = "_QFEb"}
+!CHECK: %[[val_3:.*]]:2 = hlfir.declare %[[val_2]] {uniq_name = "_QFEb"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
+!CHECK: %[[val_4:.*]] = fir.alloca i32 {bindc_name = "x", uniq_name = "_QFEx"}
+!CHECK: %[[val_5:.*]]:2 = hlfir.declare %[[val_4]] {uniq_name = "_QFEx"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
+!CHECK: %[[val_c5:.*]] = arith.constant 5 : index
+!CHECK: %[[val_6:.*]] = fir.alloca !fir.array<5xi32> {bindc_name = "y", uniq_name = "_QFEy"}
+!CHECK: %[[val_7:.*]] = fir.shape %[[val_c5]] : (index) -> !fir.shape<1>
+!CHECK: %[[val_8:.*]]:2 = hlfir.declare %[[val_6]](%[[val_7]]) {uniq_name = "_QFEy"} : (!fir.ref<!fir.array<5xi32>>, !fir.shape<1>) -> (!fir.ref<!fir.array<5xi32>>, !fir.ref<!fir.array<5xi32>>)
+!CHECK: %[[val_c2:.*]] = arith.constant 2 : index
+!CHECK: %[[val_9:.*]] = hlfir.designate %[[val_8]]#0 (%[[val_c2]])  : (!fir.ref<!fir.array<5xi32>>, index) -> !fir.ref<i32>
+!CHECK: %[[val_c8:.*]] = arith.constant 8 : i32
+!CHECK: %[[val_10:.*]] = fir.load %[[val_5]]#0 : !fir.ref<i32>
+!CHECK: %[[val_11:.*]] = arith.addi %[[val_c8]], %[[val_10]] : i32
+!CHECK: %[[val_12:.*]] = hlfir.no_reassoc %[[val_11]] : i32
+!CHECK: omp.atomic.update %[[val_9]] : !fir.ref<i32> {
+!CHECK:   ^bb0(%[[ARG:.*]]: i32):
+!CHECK:     %[[val_18:.*]] = arith.muli %[[val_12]], %[[ARG]] : i32
+!CHECK:     omp.yield(%[[val_18]] : i32)
+!CHECK: }
+!CHECK: %[[val_c2_0:.*]] = arith.constant 2 : index
+!CHECK: %[[val_13:.*]] = hlfir.designate %[[val_8]]#0 (%[[val_c2_0]])  : (!fir.ref<!fir.array<5xi32>>, index) -> !fir.ref<i32>
+!CHECK: %[[val_c8_1:.*]] = arith.constant 8 : i32
+!CHECK: omp.atomic.update %[[val_13:.*]] : !fir.ref<i32> {
+!CHECK:   ^bb0(%[[ARG:.*]]: i32):
+!CHECK:     %[[val_18:.*]] = arith.divsi %[[ARG]], %[[val_c8_1]] : i32
+!CHECK:     omp.yield(%[[val_18]] : i32)
+!CHECK: }
+!CHECK: %[[val_c8_2:.*]] = arith.constant 8 : i32
+!CHECK: %[[val_c4:.*]] = arith.constant 4 : index
+!CHECK: %[[val_14:.*]] = hlfir.designate %[[val_8]]#0 (%[[val_c4]])  : (!fir.ref<!fir.array<5xi32>>, index) -> !fir.ref<i32>
+!CHECK: %[[val_15:.*]] = fir.load %[[val_14]] : !fir.ref<i32>
+!CHECK: %[[val_16:.*]] = arith.addi %[[val_c8_2]], %[[val_15]] : i32
+!CHECK: %[[val_17:.*]] = hlfir.no_reassoc %[[val_16]] : i32
+!CHECK: omp.atomic.update %[[val_5]]#1 : !fir.ref<i32> {
+!CHECK:   ^bb0(%[[ARG:.*]]: i32):
+!CHECK:      %[[val_18:.*]] = arith.addi %[[val_17]], %[[ARG]] : i32
+!CHECK:      omp.yield(%[[val_18]] : i32)
+!CHECK: }
+!CHECK: %[[val_c8_3:.*]] = arith.constant 8 : i32
+!CHECK: omp.atomic.update %[[val_5]]#1 : !fir.ref<i32> {
+!CHECK:   ^bb0(%[[ARG]]: i32):
+!CHECK:     %[[val_18:.*]] = arith.subi %[[val_c8_3]], %[[ARG]] : i32
+!CHECK:     omp.yield(%[[val_18]] : i32)
+!CHECK:   }
+!CHECK: return
+!CHECK: }
+program sample
+
+  integer :: x
+  integer, dimension(5) :: y
+  integer :: a, b
+
+  !$omp atomic update
+    y(2) =  (8 + x) * y(2)
+  !$omp end atomic
+
+  !$omp atomic update
+    y(2) =  y(2) / 8
+  !$omp end atomic
+
+  !$omp atomic update
+    x =  (8 + y(4)) + x
+  !$omp end atomic
+
+  !$omp atomic update
+    x =  8 - x
+  !$omp end atomic
+
+end program sample


        


More information about the flang-commits mailing list