[flang-commits] [flang] [flang] Allow lowering of sub-expressions to be overridden (PR #69944)

via flang-commits flang-commits at lists.llvm.org
Tue Oct 24 01:40:45 PDT 2023


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-openacc

Author: None (jeanPerier)

<details>
<summary>Changes</summary>

OpenACC/OpenMP atomic lowering needs finer control over expression lowering. This patch allows mapping evaluate::Expr<T> to mlir::Value so that any subsequent expression lowering will use these values when an operand is a mapped Expr<T>.

This is an alternative to https://github.com/llvm/llvm-project/pull/69866, from which I took the test.

The same tests as in https://github.com/llvm/llvm-project/pull/69866 are failing because the "non atomic part" is now generated outside of the atomic.update op, which in some cases causes verification failures because it ends up in the middle of an omp.atomic.capture. I did not try fixing these failures. My patch is about the lowering infrastructure and how to use it rather than the OpenMP semantics.

---

Patch is 32.69 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/69944.diff


12 Files Affected:

- (modified) flang/include/flang/Lower/AbstractConverter.h (+10) 
- (modified) flang/lib/Lower/Bridge.cpp (+11) 
- (modified) flang/lib/Lower/ConvertExpr.cpp (+17) 
- (modified) flang/lib/Lower/ConvertExprToHLFIR.cpp (+11) 
- (modified) flang/lib/Lower/DirectivesCommon.h (+55-99) 
- (modified) flang/test/Lower/OpenACC/acc-atomic-capture.f90 (+5-5) 
- (modified) flang/test/Lower/OpenACC/acc-atomic-update-hlfir.f90 (+2-2) 
- (modified) flang/test/Lower/OpenACC/acc-atomic-update.f90 (+5-5) 
- (modified) flang/test/Lower/OpenMP/FIR/atomic-capture.f90 (+5-5) 
- (modified) flang/test/Lower/OpenMP/FIR/atomic-update.f90 (+11-11) 
- (modified) flang/test/Lower/OpenMP/atomic-update-hlfir.f90 (+2-2) 
- (added) flang/test/Lower/OpenMP/common-atomic-lowering.f90 (+74) 


``````````diff
diff --git a/flang/include/flang/Lower/AbstractConverter.h b/flang/include/flang/Lower/AbstractConverter.h
index c792e75f1146499..fa67729fe036684 100644
--- a/flang/include/flang/Lower/AbstractConverter.h
+++ b/flang/include/flang/Lower/AbstractConverter.h
@@ -60,6 +60,8 @@ using SomeExpr = Fortran::evaluate::Expr<Fortran::evaluate::SomeType>;
 using SymbolRef = Fortran::common::Reference<const Fortran::semantics::Symbol>;
 class StatementContext;
 
+using ExprToValueMap = llvm::DenseMap<const SomeExpr *, mlir::Value>;
+
 //===----------------------------------------------------------------------===//
 // AbstractConverter interface
 //===----------------------------------------------------------------------===//
@@ -90,6 +92,14 @@ class AbstractConverter {
   /// added or replaced at the inner-most level of the local symbol map.
   virtual void bindSymbol(SymbolRef sym, const fir::ExtendedValue &exval) = 0;
 
+  /// Override lowering of expression with pre-lowered values.
+  /// Associate mlir::Value to evaluate::Expr. All subsequent call to
+  /// genExprXXX() will replace any occurrence of an overridden
+  /// expression in the expression tree by the pre-lowered values.
+  virtual void overrideExprValues(const ExprToValueMap *) = 0;
+  void resetExprOverrides() { overrideExprValues(nullptr); }
+  virtual const ExprToValueMap *getExprOverrides() = 0;
+
   /// Get the label set associated with a symbol.
   virtual bool lookupLabelSet(SymbolRef sym, pft::LabelSet &labelSet) = 0;
 
diff --git a/flang/lib/Lower/Bridge.cpp b/flang/lib/Lower/Bridge.cpp
index c3afd91d7453caa..2bfab3a250e6916 100644
--- a/flang/lib/Lower/Bridge.cpp
+++ b/flang/lib/Lower/Bridge.cpp
@@ -504,6 +504,15 @@ class FirConverter : public Fortran::lower::AbstractConverter {
     addSymbol(sym, exval, /*forced=*/true);
   }
 
+  void
+  overrideExprValues(const Fortran::lower::ExprToValueMap *map) override final {
+    exprValueOverrides = map;
+  }
+
+  const Fortran::lower::ExprToValueMap *getExprOverrides() override final {
+    return exprValueOverrides;
+  }
+
   bool lookupLabelSet(Fortran::lower::SymbolRef sym,
                       Fortran::lower::pft::LabelSet &labelSet) override final {
     Fortran::lower::pft::FunctionLikeUnit &owningProc =
@@ -4890,6 +4899,8 @@ class FirConverter : public Fortran::lower::AbstractConverter {
   /// Whether an OpenMP target region or declare target function/subroutine
   /// intended for device offloading has been detected
   bool ompDeviceCodeFound = false;
+
+  const Fortran::lower::ExprToValueMap *exprValueOverrides{nullptr};
 };
 
 } // namespace
diff --git a/flang/lib/Lower/ConvertExpr.cpp b/flang/lib/Lower/ConvertExpr.cpp
index 6d2ac62b61b74c3..1a2b3856c526716 100644
--- a/flang/lib/Lower/ConvertExpr.cpp
+++ b/flang/lib/Lower/ConvertExpr.cpp
@@ -2963,8 +2963,21 @@ class ScalarExprLowering {
     return asArray(x);
   }
 
+  template <typename A>
+  mlir::Value getIfOverridenExpr(const Fortran::evaluate::Expr<A> &x) {
+    if (const Fortran::lower::ExprToValueMap *map =
+            converter.getExprOverrides()) {
+      Fortran::lower::SomeExpr someExpr = toEvExpr(x);
+      if (auto match = map->find(&someExpr); match != map->end())
+        return match->second;
+    }
+    return mlir::Value{};
+  }
+
   template <typename A>
   ExtValue gen(const Fortran::evaluate::Expr<A> &x) {
+    if (mlir::Value val = getIfOverridenExpr(x))
+      return val;
     // Whole array symbols or components, and results of transformational
     // functions already have a storage and the scalar expression lowering path
     // is used to not create a new temporary storage.
@@ -2978,6 +2991,8 @@ class ScalarExprLowering {
   }
   template <typename A>
   ExtValue genval(const Fortran::evaluate::Expr<A> &x) {
+    if (mlir::Value val = getIfOverridenExpr(x))
+      return val;
     if (isScalar(x) || Fortran::evaluate::UnwrapWholeSymbolDataRef(x) ||
         inInitializer)
       return std::visit([&](const auto &e) { return genval(e); }, x.u);
@@ -2987,6 +3002,8 @@ class ScalarExprLowering {
   template <int KIND>
   ExtValue genval(const Fortran::evaluate::Expr<Fortran::evaluate::Type<
                       Fortran::common::TypeCategory::Logical, KIND>> &exp) {
+    if (mlir::Value val = getIfOverridenExpr(exp))
+      return val;
     return std::visit([&](const auto &e) { return genval(e); }, exp.u);
   }
 
diff --git a/flang/lib/Lower/ConvertExprToHLFIR.cpp b/flang/lib/Lower/ConvertExprToHLFIR.cpp
index 4cf29c9aecbf577..1da6a5bdd54784e 100644
--- a/flang/lib/Lower/ConvertExprToHLFIR.cpp
+++ b/flang/lib/Lower/ConvertExprToHLFIR.cpp
@@ -1423,6 +1423,17 @@ class HlfirBuilder {
 
   template <typename T>
   hlfir::EntityWithAttributes gen(const Fortran::evaluate::Expr<T> &expr) {
+    if (const Fortran::lower::ExprToValueMap *map =
+            getConverter().getExprOverrides()) {
+      if constexpr (std::is_same_v<T, Fortran::evaluate::SomeType>) {
+        if (auto match = map->find(&expr); match != map->end())
+          return hlfir::EntityWithAttributes{match->second};
+      } else {
+        Fortran::lower::SomeExpr someExpr = toEvExpr(expr);
+        if (auto match = map->find(&someExpr); match != map->end())
+          return hlfir::EntityWithAttributes{match->second};
+      }
+    }
     return std::visit([&](const auto &x) { return gen(x); }, expr.u);
   }
 
diff --git a/flang/lib/Lower/DirectivesCommon.h b/flang/lib/Lower/DirectivesCommon.h
index ed44598bc925212..91801aad9c074a4 100644
--- a/flang/lib/Lower/DirectivesCommon.h
+++ b/flang/lib/Lower/DirectivesCommon.h
@@ -200,62 +200,13 @@ static inline void genOmpAccAtomicUpdateStatement(
     mlir::Type varType, const Fortran::parser::Variable &assignmentStmtVariable,
     const Fortran::parser::Expr &assignmentStmtExpr,
     [[maybe_unused]] const AtomicListT *leftHandClauseList,
-    [[maybe_unused]] const AtomicListT *rightHandClauseList) {
+    [[maybe_unused]] const AtomicListT *rightHandClauseList,
+    mlir::Operation *atomicCaptureOp = nullptr) {
   // Generate `omp.atomic.update` operation for atomic assignment statements
   fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
   mlir::Location currentLocation = converter.getCurrentLocation();
 
-  const auto *varDesignator =
-      std::get_if<Fortran::common::Indirection<Fortran::parser::Designator>>(
-          &assignmentStmtVariable.u);
-  assert(varDesignator && "Variable designator for atomic update assignment "
-                          "statement does not exist");
-  const Fortran::parser::Name *name =
-      Fortran::semantics::getDesignatorNameIfDataRef(varDesignator->value());
-  if (!name)
-    TODO(converter.getCurrentLocation(),
-         "Array references as atomic update variable");
-  assert(name && name->symbol &&
-         "No symbol attached to atomic update variable");
-  if (Fortran::semantics::IsAllocatableOrPointer(name->symbol->GetUltimate()))
-    converter.bindSymbol(*name->symbol, lhsAddr);
-
-  //  Lowering is in two steps :
-  //  subroutine sb
-  //    integer :: a, b
-  //    !$omp atomic update
-  //      a = a + b
-  //  end subroutine
-  //
-  //  1. Lower to scf.execute_region_op
-  //
-  //  func.func @_QPsb() {
-  //    %0 = fir.alloca i32 {bindc_name = "a", uniq_name = "_QFsbEa"}
-  //    %1 = fir.alloca i32 {bindc_name = "b", uniq_name = "_QFsbEb"}
-  //    %2 = scf.execute_region -> i32 {
-  //      %3 = fir.load %0 : !fir.ref<i32>
-  //      %4 = fir.load %1 : !fir.ref<i32>
-  //      %5 = arith.addi %3, %4 : i32
-  //      scf.yield %5 : i32
-  //    }
-  //    return
-  //  }
-  auto tempOp =
-      firOpBuilder.create<mlir::scf::ExecuteRegionOp>(currentLocation, varType);
-  firOpBuilder.createBlock(&tempOp.getRegion());
-  mlir::Block &block = tempOp.getRegion().back();
-  firOpBuilder.setInsertionPointToEnd(&block);
-  Fortran::lower::StatementContext stmtCtx;
-  mlir::Value rhsExpr = fir::getBase(converter.genExprValue(
-      *Fortran::semantics::GetExpr(assignmentStmtExpr), stmtCtx));
-  mlir::Value convertResult =
-      firOpBuilder.createConvert(currentLocation, varType, rhsExpr);
-  // Insert the terminator: YieldOp.
-  firOpBuilder.create<mlir::scf::YieldOp>(currentLocation, convertResult);
-  firOpBuilder.setInsertionPointToStart(&block);
-
-  //  2. Create the omp.atomic.update Operation using the Operations in the
-  //     temporary scf.execute_region Operation.
+  //  Create the omp.atomic.update Operation
   //
   //  func.func @_QPsb() {
   //    %0 = fir.alloca i32 {bindc_name = "a", uniq_name = "_QFsbEa"}
@@ -269,11 +220,37 @@ static inline void genOmpAccAtomicUpdateStatement(
   //    }
   //    return
   //  }
-  mlir::Value updateVar = converter.getSymbolAddress(*name->symbol);
-  if (auto decl = updateVar.getDefiningOp<hlfir::DeclareOp>())
-    updateVar = decl.getBase();
 
-  firOpBuilder.setInsertionPointAfter(tempOp);
+  Fortran::lower::ExprToValueMap exprValueOverrides;
+  // Lower any non atomic sub-expression before the atomic operation, and
+  // map its lowered value to the semantic representation.
+  const Fortran::lower::SomeExpr *nonAtomicSubExpr{nullptr};
+  std::visit(
+      [&](const auto &op) -> void {
+        using T = std::decay_t<decltype(op)>;
+        if constexpr (std::is_base_of<Fortran::parser::Expr::IntrinsicBinary,
+                                      T>::value) {
+          const auto &exprLeft{std::get<0>(op.t)};
+          const auto &exprRight{std::get<1>(op.t)};
+          if (exprLeft.value().source == assignmentStmtVariable.GetSource())
+            nonAtomicSubExpr = Fortran::semantics::GetExpr(exprRight);
+          else
+            nonAtomicSubExpr = Fortran::semantics::GetExpr(exprLeft);
+        }
+      },
+      assignmentStmtExpr.u);
+  StatementContext nonAtomicStmtCtx;
+  if (nonAtomicSubExpr) {
+    // Generate non atomic part before all the atomic operations.
+    auto insertionPoint = firOpBuilder.saveInsertionPoint();
+    if (atomicCaptureOp)
+      firOpBuilder.setInsertionPoint(atomicCaptureOp);
+    mlir::Value nonAtomicVal = fir::getBase(converter.genExprValue(
+        currentLocation, *nonAtomicSubExpr, nonAtomicStmtCtx));
+    exprValueOverrides.try_emplace(nonAtomicSubExpr, nonAtomicVal);
+    if (atomicCaptureOp)
+      firOpBuilder.restoreInsertionPoint(insertionPoint);
+  }
 
   mlir::Operation *atomicUpdateOp = nullptr;
   if constexpr (std::is_same<AtomicListT,
@@ -289,10 +266,10 @@ static inline void genOmpAccAtomicUpdateStatement(
       genOmpAtomicHintAndMemoryOrderClauses(converter, *rightHandClauseList,
                                             hint, memoryOrder);
     atomicUpdateOp = firOpBuilder.create<mlir::omp::AtomicUpdateOp>(
-        currentLocation, updateVar, hint, memoryOrder);
+        currentLocation, lhsAddr, hint, memoryOrder);
   } else {
     atomicUpdateOp = firOpBuilder.create<mlir::acc::AtomicUpdateOp>(
-        currentLocation, updateVar);
+        currentLocation, lhsAddr);
   }
 
   llvm::SmallVector<mlir::Type> varTys = {varType};
@@ -301,38 +278,25 @@ static inline void genOmpAccAtomicUpdateStatement(
   mlir::Value val =
       fir::getBase(atomicUpdateOp->getRegion(0).front().getArgument(0));
 
-  llvm::SmallVector<mlir::Operation *> ops;
-  for (mlir::Operation &op : tempOp.getRegion().getOps())
-    ops.push_back(&op);
-
-  // SCF Yield is converted to OMP Yield. All other operations are copied
-  for (mlir::Operation *op : ops) {
-    if (auto y = mlir::dyn_cast<mlir::scf::YieldOp>(op)) {
-      firOpBuilder.setInsertionPointToEnd(
-          &atomicUpdateOp->getRegion(0).front());
-      if constexpr (std::is_same<AtomicListT,
-                                 Fortran::parser::OmpAtomicClauseList>()) {
-        firOpBuilder.create<mlir::omp::YieldOp>(currentLocation,
-                                                y.getResults());
-      } else {
-        firOpBuilder.create<mlir::acc::YieldOp>(currentLocation,
-                                                y.getResults());
-      }
-      op->erase();
+  exprValueOverrides.try_emplace(
+      Fortran::semantics::GetExpr(assignmentStmtVariable), val);
+  {
+    // statement context inside the atomic block.
+    converter.overrideExprValues(&exprValueOverrides);
+    Fortran::lower::StatementContext atomicStmtCtx;
+    mlir::Value rhsExpr = fir::getBase(converter.genExprValue(
+        *Fortran::semantics::GetExpr(assignmentStmtExpr), atomicStmtCtx));
+    mlir::Value convertResult =
+        firOpBuilder.createConvert(currentLocation, varType, rhsExpr);
+    if constexpr (std::is_same<AtomicListT,
+                               Fortran::parser::OmpAtomicClauseList>()) {
+      firOpBuilder.create<mlir::omp::YieldOp>(currentLocation, convertResult);
     } else {
-      op->remove();
-      atomicUpdateOp->getRegion(0).front().push_back(op);
+      firOpBuilder.create<mlir::acc::YieldOp>(currentLocation, convertResult);
     }
+    converter.resetExprOverrides();
   }
-
-  // Remove the load and replace all uses of load with the block argument
-  for (mlir::Operation &op : atomicUpdateOp->getRegion(0).getOps()) {
-    fir::LoadOp y = mlir::dyn_cast<fir::LoadOp>(&op);
-    if (y && y.getMemref() == updateVar)
-      y.getRes().replaceAllUsesWith(val);
-  }
-
-  tempOp.erase();
+  firOpBuilder.setInsertionPointAfter(atomicUpdateOp);
 }
 
 /// Processes an atomic construct with write clause.
@@ -423,11 +387,7 @@ void genOmpAccAtomicUpdate(Fortran::lower::AbstractConverter &converter,
   Fortran::lower::StatementContext stmtCtx;
   mlir::Value lhsAddr = fir::getBase(converter.genExprAddr(
       *Fortran::semantics::GetExpr(assignmentStmtVariable), stmtCtx));
-  mlir::Type varType =
-      fir::getBase(
-          converter.genExprValue(
-              *Fortran::semantics::GetExpr(assignmentStmtVariable), stmtCtx))
-          .getType();
+  mlir::Type varType = fir::unwrapRefType(lhsAddr.getType());
   genOmpAccAtomicUpdateStatement<AtomicListT>(
       converter, lhsAddr, varType, assignmentStmtVariable, assignmentStmtExpr,
       leftHandClauseList, rightHandClauseList);
@@ -450,11 +410,7 @@ void genOmpAtomic(Fortran::lower::AbstractConverter &converter,
   Fortran::lower::StatementContext stmtCtx;
   mlir::Value lhsAddr = fir::getBase(converter.genExprAddr(
       *Fortran::semantics::GetExpr(assignmentStmtVariable), stmtCtx));
-  mlir::Type varType =
-      fir::getBase(
-          converter.genExprValue(
-              *Fortran::semantics::GetExpr(assignmentStmtVariable), stmtCtx))
-          .getType();
+  mlir::Type varType = fir::unwrapRefType(lhsAddr.getType());
   // If atomic-clause is not present on the construct, the behaviour is as if
   // the update clause is specified (for both OpenMP and OpenACC).
   genOmpAccAtomicUpdateStatement<AtomicListT>(
@@ -551,7 +507,7 @@ void genOmpAccAtomicCapture(Fortran::lower::AbstractConverter &converter,
       genOmpAccAtomicUpdateStatement<AtomicListT>(
           converter, stmt1RHSArg, stmt2VarType, stmt2Var, stmt2Expr,
           /*leftHandClauseList=*/nullptr,
-          /*rightHandClauseList=*/nullptr);
+          /*rightHandClauseList=*/nullptr, atomicCaptureOp);
     } else {
       // Atomic capture construct is of the form [capture-stmt, write-stmt]
       const Fortran::semantics::SomeExpr &fromExpr =
@@ -580,7 +536,7 @@ void genOmpAccAtomicCapture(Fortran::lower::AbstractConverter &converter,
     genOmpAccAtomicUpdateStatement<AtomicListT>(
         converter, stmt1LHSArg, stmt1VarType, stmt1Var, stmt1Expr,
         /*leftHandClauseList=*/nullptr,
-        /*rightHandClauseList=*/nullptr);
+        /*rightHandClauseList=*/nullptr, atomicCaptureOp);
   }
   firOpBuilder.setInsertionPointToEnd(&block);
   if constexpr (std::is_same<AtomicListT,
diff --git a/flang/test/Lower/OpenACC/acc-atomic-capture.f90 b/flang/test/Lower/OpenACC/acc-atomic-capture.f90
index 1a5fe8d57c1533a..382991cf7221ba7 100644
--- a/flang/test/Lower/OpenACC/acc-atomic-capture.f90
+++ b/flang/test/Lower/OpenACC/acc-atomic-capture.f90
@@ -7,11 +7,11 @@ program acc_atomic_capture_test
 
 !CHECK: %[[X:.*]] = fir.alloca i32 {bindc_name = "x", uniq_name = "_QFEx"}
 !CHECK: %[[Y:.*]] = fir.alloca i32 {bindc_name = "y", uniq_name = "_QFEy"}
+!CHECK: %[[temp:.*]] = fir.load %[[X]] : !fir.ref<i32>
 !CHECK: acc.atomic.capture {
 !CHECK: acc.atomic.read %[[X]] = %[[Y]] : !fir.ref<i32>
 !CHECK: acc.atomic.update %[[Y]] : !fir.ref<i32> {
 !CHECK: ^bb0(%[[ARG:.*]]: i32):
-!CHECK: %[[temp:.*]] = fir.load %[[X]] : !fir.ref<i32>
 !CHECK: %[[result:.*]] = arith.addi %[[temp]], %[[ARG]] : i32
 !CHECK: acc.yield %[[result]] : i32
 !CHECK: }
@@ -23,10 +23,10 @@ program acc_atomic_capture_test
     !$acc end atomic
 
 
+!CHECK: %[[temp:.*]] = fir.load %[[X]] : !fir.ref<i32>
 !CHECK: acc.atomic.capture {
 !CHECK: acc.atomic.update %[[Y]] : !fir.ref<i32> {
 !CHECK: ^bb0(%[[ARG:.*]]: i32):
-!CHECK: %[[temp:.*]] = fir.load %[[X]] : !fir.ref<i32>
 !CHECK: %[[result:.*]] = arith.muli %[[temp]], %[[ARG]] : i32
 !CHECK: acc.yield %[[result]] : i32
 !CHECK: }
@@ -76,12 +76,12 @@ subroutine pointers_in_atomic_capture()
 !CHECK: %[[loaded_A_addr:.*]] = fir.box_addr %[[loaded_A]] : (!fir.box<!fir.ptr<i32>>) -> !fir.ptr<i32>
 !CHECK: %[[loaded_B:.*]] = fir.load %[[B]] : !fir.ref<!fir.box<!fir.ptr<i32>>>
 !CHECK: %[[loaded_B_addr:.*]] = fir.box_addr %[[loaded_B]] : (!fir.box<!fir.ptr<i32>>) -> !fir.ptr<i32>
-!CHECK: acc.atomic.capture   {
-!CHECK: acc.atomic.update %[[loaded_A_addr]] : !fir.ptr<i32> {
-!CHECK: ^bb0(%[[ARG:.*]]: i32):
 !CHECK: %[[PRIVATE_LOADED_B:.*]] = fir.load %[[B]] : !fir.ref<!fir.box<!fir.ptr<i32>>>
 !CHECK: %[[PRIVATE_LOADED_B_addr:.*]] = fir.box_addr %[[PRIVATE_LOADED_B]] : (!fir.box<!fir.ptr<i32>>) -> !fir.ptr<i32>
 !CHECK: %[[loaded_value:.*]] = fir.load %[[PRIVATE_LOADED_B_addr]] : !fir.ptr<i32>
+!CHECK: acc.atomic.capture   {
+!CHECK: acc.atomic.update %[[loaded_A_addr]] : !fir.ptr<i32> {
+!CHECK: ^bb0(%[[ARG:.*]]: i32):
 !CHECK: %[[result:.*]] = arith.addi %[[ARG]], %[[loaded_value]] : i32
 !CHECK: acc.yield %[[result]] : i32
 !CHECK: }
diff --git a/flang/test/Lower/OpenACC/acc-atomic-update-hlfir.f90 b/flang/test/Lower/OpenACC/acc-atomic-update-hlfir.f90
index 24dd0ee5a8999e4..b2a993ddd825105 100644
--- a/flang/test/Lower/OpenACC/acc-atomic-update-hlfir.f90
+++ b/flang/test/Lower/OpenACC/acc-atomic-update-hlfir.f90
@@ -14,9 +14,9 @@ subroutine sb
 !CHECK:   %[[X_DECL:.*]]:2 = hlfir.declare %[[X_REF]] {uniq_name = "_QFsbEx"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
 !CHECK:   %[[Y_REF:.*]] = fir.alloca i32 {bindc_name = "y", uniq_name = "_QFsbEy"}
 !CHECK:   %[[Y_DECL:.*]]:2 = hlfir.declare %[[Y_REF]] {uniq_name = "_QFsbEy"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
-!CHECK:   acc.atomic.update   %[[X_DECL]]#0 : !fir.ref<i32> {
+!CHECK:   %[[Y_VAL:.*]] = fir.load %[[Y_DECL]]#0 : !fir.ref<i32>
+!CHECK:   acc.atomic.update   %[[X_DECL]]#1 : !fir.ref<i32> {
 !CHECK:   ^bb0(%[[ARG_X:.*]]: i32):
-!CHECK:     %[[Y_VAL:.*]] = fir.load %[[Y_DECL]]#0 : !fir.ref<i32>
 !CHECK:     %[[X_UPDATE_VAL:.*]] = arith.addi %[[ARG_X]], %[[Y_VAL]] : i32
 !CHECK:     acc.yield %[[X_UPDATE_VAL]] : i32
 !CHECK:   }
diff --git a/flang/test/Lower/OpenACC/acc-atomic-update.f90 b/flang/test/Lower/OpenACC/acc-atomic-update.f90
index 546d012982e23ba..96e1d5e58e6f482 100644
--- a/flang/test/Lower/OpenACC/acc-atomic-update.f90
+++ b/flang/test/Lower/OpenACC/acc-atomic-update.f90
@@ -31,25 +31,25 @@ program acc_atomic_update_test
 !CHECK: %{{.*}} = fir.convert %[[D_ADDR]] : (!fir.ref<i32>) -> !fir.ptr<i32>
 !CHECK: fir.store {{.*}} to %[[B_ADDR]] : !fir.ref<!fir.ptr<i32>>
 !CHECK: %[[LOADED_A:.*]] = fir.load %[[A_ADDR]] : !fir.ref<!fir.ptr<i32>>
+!CHECK: %[[LOADED_B:.*]] = fir.load %[[B_ADDR]] : !fir.ref<!fir.ptr<i32>>
+!CHECK: %{{.*}} = fir.load %[[LOADED_B]] : !fir.ptr<i32>
 !CHECK:  acc.atomic.update   %[[LOADED_A]] : !fir.ptr<i32> {
 !CHECK:  ^bb0(%[[ARG:.*]]: i32):
-!CHECK:    %[[LOADED_B:.*]] = fir.load %[[B_ADDR]] : !fir.ref<!fir.ptr<i32>>
-!CHECK:    %{{.*}} = fir.load %[[LOADED_B]] : !fir.ptr<i32>
 !CHECK:    %[[RESULT:.*]] = arith.addi %[[ARG]], %{{.*}} : i32
 !CHECK:    acc.yield %[[RESULT]] : i32
 !CHECK: }
     !$acc atomic update
         a = a + b 
 
+!CHECK: {{.*}} = arith.constant 1 : i32
 !CHECK: acc.atomic.update   %[[Y]] : !fir.ref<i32> {
 !CHECK:  ^bb0(%[[ARG:.*]]: i32):
-!CHECK:    {{.*}} = arith.constant 1 : i32
 !CHECK:    %[[RESULT:.*]] = arith.addi %[[ARG]], {{.*}} : i32
 !CHECK:    acc.yiel...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/69944


More information about the flang-commits mailing list