[flang-commits] [flang] [flang] Allow lowering of sub-expressions to be overridden (PR #69944)

via flang-commits flang-commits at lists.llvm.org
Tue Oct 24 01:22:22 PDT 2023


https://github.com/jeanPerier updated https://github.com/llvm/llvm-project/pull/69944

>From 887f368b106bc66fdbbd2ecf82d83a0d88372ed4 Mon Sep 17 00:00:00 2001
From: Jean Perier <jperier at nvidia.com>
Date: Mon, 23 Oct 2023 09:27:05 -0700
Subject: [PATCH 1/2] [flang] Allow lowering of sub-expressions to be
 overridden

OpenACC/OpenMP atomic lowering needs finer control over expression
lowering. This patch allows mapping evaluate::Expr<T> to mlir::Value
so that any subsequent expression lowering will use these values when
an operand is a mapped Expr<T>.

This is an alternative to https://github.com/llvm/llvm-project/pull/69866,
from which I took the test.

The same tests as in https://github.com/llvm/llvm-project/pull/69866 are
failing because the "non atomic part" is now outside of the atomic.update
op, which in some cases causes a verification failure because it is
generated in the middle of an omp.atomic.capture. I did not try fixing
these failures. My patch is about the lowering infrastructure and how to
use it rather than the OpenMP semantics.

Co-authored-by: Nimish Mishra <neelam.nimish at gmail.com>
---
 flang/include/flang/Lower/AbstractConverter.h |  10 ++
 flang/lib/Lower/Bridge.cpp                    |  11 ++
 flang/lib/Lower/ConvertExpr.cpp               |  17 +++
 flang/lib/Lower/ConvertExprToHLFIR.cpp        |  11 ++
 flang/lib/Lower/DirectivesCommon.h            | 141 ++++++------------
 .../Lower/OpenACC/acc-atomic-update-hlfir.f90 |   2 +-
 .../test/Lower/OpenMP/atomic-update-hlfir.f90 |   4 +-
 .../Lower/OpenMP/common-atomic-lowering.f90   |  74 +++++++++
 8 files changed, 171 insertions(+), 99 deletions(-)
 create mode 100644 flang/test/Lower/OpenMP/common-atomic-lowering.f90

diff --git a/flang/include/flang/Lower/AbstractConverter.h b/flang/include/flang/Lower/AbstractConverter.h
index c792e75f1146499..fa67729fe036684 100644
--- a/flang/include/flang/Lower/AbstractConverter.h
+++ b/flang/include/flang/Lower/AbstractConverter.h
@@ -60,6 +60,8 @@ using SomeExpr = Fortran::evaluate::Expr<Fortran::evaluate::SomeType>;
 using SymbolRef = Fortran::common::Reference<const Fortran::semantics::Symbol>;
 class StatementContext;
 
+using ExprToValueMap = llvm::DenseMap<const SomeExpr *, mlir::Value>;
+
 //===----------------------------------------------------------------------===//
 // AbstractConverter interface
 //===----------------------------------------------------------------------===//
@@ -90,6 +92,14 @@ class AbstractConverter {
   /// added or replaced at the inner-most level of the local symbol map.
   virtual void bindSymbol(SymbolRef sym, const fir::ExtendedValue &exval) = 0;
 
+  /// Override lowering of expression with pre-lowered values.
+  /// Associate mlir::Value to evaluate::Expr. All subsequent call to
+  /// genExprXXX() will replace any occurrence of an overridden
+  /// expression in the expression tree by the pre-lowered values.
+  virtual void overrideExprValues(const ExprToValueMap *) = 0;
+  void resetExprOverrides() { overrideExprValues(nullptr); }
+  virtual const ExprToValueMap *getExprOverrides() = 0;
+
   /// Get the label set associated with a symbol.
   virtual bool lookupLabelSet(SymbolRef sym, pft::LabelSet &labelSet) = 0;
 
diff --git a/flang/lib/Lower/Bridge.cpp b/flang/lib/Lower/Bridge.cpp
index c3afd91d7453caa..2bfab3a250e6916 100644
--- a/flang/lib/Lower/Bridge.cpp
+++ b/flang/lib/Lower/Bridge.cpp
@@ -504,6 +504,15 @@ class FirConverter : public Fortran::lower::AbstractConverter {
     addSymbol(sym, exval, /*forced=*/true);
   }
 
+  void
+  overrideExprValues(const Fortran::lower::ExprToValueMap *map) override final {
+    exprValueOverrides = map;
+  }
+
+  const Fortran::lower::ExprToValueMap *getExprOverrides() override final {
+    return exprValueOverrides;
+  }
+
   bool lookupLabelSet(Fortran::lower::SymbolRef sym,
                       Fortran::lower::pft::LabelSet &labelSet) override final {
     Fortran::lower::pft::FunctionLikeUnit &owningProc =
@@ -4890,6 +4899,8 @@ class FirConverter : public Fortran::lower::AbstractConverter {
   /// Whether an OpenMP target region or declare target function/subroutine
   /// intended for device offloading has been detected
   bool ompDeviceCodeFound = false;
+
+  const Fortran::lower::ExprToValueMap *exprValueOverrides{nullptr};
 };
 
 } // namespace
diff --git a/flang/lib/Lower/ConvertExpr.cpp b/flang/lib/Lower/ConvertExpr.cpp
index 6d2ac62b61b74c3..1a2b3856c526716 100644
--- a/flang/lib/Lower/ConvertExpr.cpp
+++ b/flang/lib/Lower/ConvertExpr.cpp
@@ -2963,8 +2963,21 @@ class ScalarExprLowering {
     return asArray(x);
   }
 
+  template <typename A>
+  mlir::Value getIfOverridenExpr(const Fortran::evaluate::Expr<A> &x) {
+    if (const Fortran::lower::ExprToValueMap *map =
+            converter.getExprOverrides()) {
+      Fortran::lower::SomeExpr someExpr = toEvExpr(x);
+      if (auto match = map->find(&someExpr); match != map->end())
+        return match->second;
+    }
+    return mlir::Value{};
+  }
+
   template <typename A>
   ExtValue gen(const Fortran::evaluate::Expr<A> &x) {
+    if (mlir::Value val = getIfOverridenExpr(x))
+      return val;
     // Whole array symbols or components, and results of transformational
     // functions already have a storage and the scalar expression lowering path
     // is used to not create a new temporary storage.
@@ -2978,6 +2991,8 @@ class ScalarExprLowering {
   }
   template <typename A>
   ExtValue genval(const Fortran::evaluate::Expr<A> &x) {
+    if (mlir::Value val = getIfOverridenExpr(x))
+      return val;
     if (isScalar(x) || Fortran::evaluate::UnwrapWholeSymbolDataRef(x) ||
         inInitializer)
       return std::visit([&](const auto &e) { return genval(e); }, x.u);
@@ -2987,6 +3002,8 @@ class ScalarExprLowering {
   template <int KIND>
   ExtValue genval(const Fortran::evaluate::Expr<Fortran::evaluate::Type<
                       Fortran::common::TypeCategory::Logical, KIND>> &exp) {
+    if (mlir::Value val = getIfOverridenExpr(exp))
+      return val;
     return std::visit([&](const auto &e) { return genval(e); }, exp.u);
   }
 
diff --git a/flang/lib/Lower/ConvertExprToHLFIR.cpp b/flang/lib/Lower/ConvertExprToHLFIR.cpp
index 4cf29c9aecbf577..1da6a5bdd54784e 100644
--- a/flang/lib/Lower/ConvertExprToHLFIR.cpp
+++ b/flang/lib/Lower/ConvertExprToHLFIR.cpp
@@ -1423,6 +1423,17 @@ class HlfirBuilder {
 
   template <typename T>
   hlfir::EntityWithAttributes gen(const Fortran::evaluate::Expr<T> &expr) {
+    if (const Fortran::lower::ExprToValueMap *map =
+            getConverter().getExprOverrides()) {
+      if constexpr (std::is_same_v<T, Fortran::evaluate::SomeType>) {
+        if (auto match = map->find(&expr); match != map->end())
+          return hlfir::EntityWithAttributes{match->second};
+      } else {
+        Fortran::lower::SomeExpr someExpr = toEvExpr(expr);
+        if (auto match = map->find(&someExpr); match != map->end())
+          return hlfir::EntityWithAttributes{match->second};
+      }
+    }
     return std::visit([&](const auto &x) { return gen(x); }, expr.u);
   }
 
diff --git a/flang/lib/Lower/DirectivesCommon.h b/flang/lib/Lower/DirectivesCommon.h
index ed44598bc925212..68109315edba40c 100644
--- a/flang/lib/Lower/DirectivesCommon.h
+++ b/flang/lib/Lower/DirectivesCommon.h
@@ -205,57 +205,7 @@ static inline void genOmpAccAtomicUpdateStatement(
   fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
   mlir::Location currentLocation = converter.getCurrentLocation();
 
-  const auto *varDesignator =
-      std::get_if<Fortran::common::Indirection<Fortran::parser::Designator>>(
-          &assignmentStmtVariable.u);
-  assert(varDesignator && "Variable designator for atomic update assignment "
-                          "statement does not exist");
-  const Fortran::parser::Name *name =
-      Fortran::semantics::getDesignatorNameIfDataRef(varDesignator->value());
-  if (!name)
-    TODO(converter.getCurrentLocation(),
-         "Array references as atomic update variable");
-  assert(name && name->symbol &&
-         "No symbol attached to atomic update variable");
-  if (Fortran::semantics::IsAllocatableOrPointer(name->symbol->GetUltimate()))
-    converter.bindSymbol(*name->symbol, lhsAddr);
-
-  //  Lowering is in two steps :
-  //  subroutine sb
-  //    integer :: a, b
-  //    !$omp atomic update
-  //      a = a + b
-  //  end subroutine
-  //
-  //  1. Lower to scf.execute_region_op
-  //
-  //  func.func @_QPsb() {
-  //    %0 = fir.alloca i32 {bindc_name = "a", uniq_name = "_QFsbEa"}
-  //    %1 = fir.alloca i32 {bindc_name = "b", uniq_name = "_QFsbEb"}
-  //    %2 = scf.execute_region -> i32 {
-  //      %3 = fir.load %0 : !fir.ref<i32>
-  //      %4 = fir.load %1 : !fir.ref<i32>
-  //      %5 = arith.addi %3, %4 : i32
-  //      scf.yield %5 : i32
-  //    }
-  //    return
-  //  }
-  auto tempOp =
-      firOpBuilder.create<mlir::scf::ExecuteRegionOp>(currentLocation, varType);
-  firOpBuilder.createBlock(&tempOp.getRegion());
-  mlir::Block &block = tempOp.getRegion().back();
-  firOpBuilder.setInsertionPointToEnd(&block);
-  Fortran::lower::StatementContext stmtCtx;
-  mlir::Value rhsExpr = fir::getBase(converter.genExprValue(
-      *Fortran::semantics::GetExpr(assignmentStmtExpr), stmtCtx));
-  mlir::Value convertResult =
-      firOpBuilder.createConvert(currentLocation, varType, rhsExpr);
-  // Insert the terminator: YieldOp.
-  firOpBuilder.create<mlir::scf::YieldOp>(currentLocation, convertResult);
-  firOpBuilder.setInsertionPointToStart(&block);
-
-  //  2. Create the omp.atomic.update Operation using the Operations in the
-  //     temporary scf.execute_region Operation.
+  //  Create the omp.atomic.update Operation
   //
   //  func.func @_QPsb() {
   //    %0 = fir.alloca i32 {bindc_name = "a", uniq_name = "_QFsbEa"}
@@ -269,11 +219,31 @@ static inline void genOmpAccAtomicUpdateStatement(
   //    }
   //    return
   //  }
-  mlir::Value updateVar = converter.getSymbolAddress(*name->symbol);
-  if (auto decl = updateVar.getDefiningOp<hlfir::DeclareOp>())
-    updateVar = decl.getBase();
 
-  firOpBuilder.setInsertionPointAfter(tempOp);
+  Fortran::lower::ExprToValueMap exprValueOverrides;
+  // Lower any non atomic sub-expression before the atomic operation, and
+  // map its lowered value to the semantic representation.
+  const Fortran::lower::SomeExpr *nonAtomicSubExpr{nullptr};
+  std::visit(
+      [&](const auto &op) -> void {
+        using T = std::decay_t<decltype(op)>;
+        if constexpr (std::is_base_of<Fortran::parser::Expr::IntrinsicBinary,
+                                      T>::value) {
+          const auto &exprLeft{std::get<0>(op.t)};
+          const auto &exprRight{std::get<1>(op.t)};
+          if (exprLeft.value().source == assignmentStmtVariable.GetSource())
+            nonAtomicSubExpr = Fortran::semantics::GetExpr(exprRight);
+          else
+            nonAtomicSubExpr = Fortran::semantics::GetExpr(exprLeft);
+        }
+      },
+      assignmentStmtExpr.u);
+  StatementContext nonAtomicStmtCtx;
+  if (nonAtomicSubExpr) {
+    mlir::Value nonAtomicVal = fir::getBase(converter.genExprValue(
+        currentLocation, *nonAtomicSubExpr, nonAtomicStmtCtx));
+    exprValueOverrides.try_emplace(nonAtomicSubExpr, nonAtomicVal);
+  }
 
   mlir::Operation *atomicUpdateOp = nullptr;
   if constexpr (std::is_same<AtomicListT,
@@ -289,10 +259,10 @@ static inline void genOmpAccAtomicUpdateStatement(
       genOmpAtomicHintAndMemoryOrderClauses(converter, *rightHandClauseList,
                                             hint, memoryOrder);
     atomicUpdateOp = firOpBuilder.create<mlir::omp::AtomicUpdateOp>(
-        currentLocation, updateVar, hint, memoryOrder);
+        currentLocation, lhsAddr, hint, memoryOrder);
   } else {
     atomicUpdateOp = firOpBuilder.create<mlir::acc::AtomicUpdateOp>(
-        currentLocation, updateVar);
+        currentLocation, lhsAddr);
   }
 
   llvm::SmallVector<mlir::Type> varTys = {varType};
@@ -301,38 +271,25 @@ static inline void genOmpAccAtomicUpdateStatement(
   mlir::Value val =
       fir::getBase(atomicUpdateOp->getRegion(0).front().getArgument(0));
 
-  llvm::SmallVector<mlir::Operation *> ops;
-  for (mlir::Operation &op : tempOp.getRegion().getOps())
-    ops.push_back(&op);
-
-  // SCF Yield is converted to OMP Yield. All other operations are copied
-  for (mlir::Operation *op : ops) {
-    if (auto y = mlir::dyn_cast<mlir::scf::YieldOp>(op)) {
-      firOpBuilder.setInsertionPointToEnd(
-          &atomicUpdateOp->getRegion(0).front());
-      if constexpr (std::is_same<AtomicListT,
-                                 Fortran::parser::OmpAtomicClauseList>()) {
-        firOpBuilder.create<mlir::omp::YieldOp>(currentLocation,
-                                                y.getResults());
-      } else {
-        firOpBuilder.create<mlir::acc::YieldOp>(currentLocation,
-                                                y.getResults());
-      }
-      op->erase();
+  exprValueOverrides.try_emplace(
+      Fortran::semantics::GetExpr(assignmentStmtVariable), val);
+  {
+    // statement context inside the atomic block.
+    converter.overrideExprValues(&exprValueOverrides);
+    Fortran::lower::StatementContext atomicStmtCtx;
+    mlir::Value rhsExpr = fir::getBase(converter.genExprValue(
+        *Fortran::semantics::GetExpr(assignmentStmtExpr), atomicStmtCtx));
+    mlir::Value convertResult =
+        firOpBuilder.createConvert(currentLocation, varType, rhsExpr);
+    if constexpr (std::is_same<AtomicListT,
+                               Fortran::parser::OmpAtomicClauseList>()) {
+      firOpBuilder.create<mlir::omp::YieldOp>(currentLocation, convertResult);
     } else {
-      op->remove();
-      atomicUpdateOp->getRegion(0).front().push_back(op);
+      firOpBuilder.create<mlir::acc::YieldOp>(currentLocation, convertResult);
     }
+    converter.resetExprOverrides();
   }
-
-  // Remove the load and replace all uses of load with the block argument
-  for (mlir::Operation &op : atomicUpdateOp->getRegion(0).getOps()) {
-    fir::LoadOp y = mlir::dyn_cast<fir::LoadOp>(&op);
-    if (y && y.getMemref() == updateVar)
-      y.getRes().replaceAllUsesWith(val);
-  }
-
-  tempOp.erase();
+  firOpBuilder.setInsertionPointAfter(atomicUpdateOp);
 }
 
 /// Processes an atomic construct with write clause.
@@ -423,11 +380,7 @@ void genOmpAccAtomicUpdate(Fortran::lower::AbstractConverter &converter,
   Fortran::lower::StatementContext stmtCtx;
   mlir::Value lhsAddr = fir::getBase(converter.genExprAddr(
       *Fortran::semantics::GetExpr(assignmentStmtVariable), stmtCtx));
-  mlir::Type varType =
-      fir::getBase(
-          converter.genExprValue(
-              *Fortran::semantics::GetExpr(assignmentStmtVariable), stmtCtx))
-          .getType();
+  mlir::Type varType = fir::unwrapRefType(lhsAddr.getType());
   genOmpAccAtomicUpdateStatement<AtomicListT>(
       converter, lhsAddr, varType, assignmentStmtVariable, assignmentStmtExpr,
       leftHandClauseList, rightHandClauseList);
@@ -450,11 +403,7 @@ void genOmpAtomic(Fortran::lower::AbstractConverter &converter,
   Fortran::lower::StatementContext stmtCtx;
   mlir::Value lhsAddr = fir::getBase(converter.genExprAddr(
       *Fortran::semantics::GetExpr(assignmentStmtVariable), stmtCtx));
-  mlir::Type varType =
-      fir::getBase(
-          converter.genExprValue(
-              *Fortran::semantics::GetExpr(assignmentStmtVariable), stmtCtx))
-          .getType();
+  mlir::Type varType = fir::unwrapRefType(lhsAddr.getType());
   // If atomic-clause is not present on the construct, the behaviour is as if
   // the update clause is specified (for both OpenMP and OpenACC).
   genOmpAccAtomicUpdateStatement<AtomicListT>(
diff --git a/flang/test/Lower/OpenACC/acc-atomic-update-hlfir.f90 b/flang/test/Lower/OpenACC/acc-atomic-update-hlfir.f90
index 24dd0ee5a8999e4..675373f5d03d81f 100644
--- a/flang/test/Lower/OpenACC/acc-atomic-update-hlfir.f90
+++ b/flang/test/Lower/OpenACC/acc-atomic-update-hlfir.f90
@@ -14,7 +14,7 @@ subroutine sb
 !CHECK:   %[[X_DECL:.*]]:2 = hlfir.declare %[[X_REF]] {uniq_name = "_QFsbEx"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
 !CHECK:   %[[Y_REF:.*]] = fir.alloca i32 {bindc_name = "y", uniq_name = "_QFsbEy"}
 !CHECK:   %[[Y_DECL:.*]]:2 = hlfir.declare %[[Y_REF]] {uniq_name = "_QFsbEy"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
-!CHECK:   acc.atomic.update   %[[X_DECL]]#0 : !fir.ref<i32> {
+!CHECK:   acc.atomic.update   %[[X_DECL]]#1 : !fir.ref<i32> {
 !CHECK:   ^bb0(%[[ARG_X:.*]]: i32):
 !CHECK:     %[[Y_VAL:.*]] = fir.load %[[Y_DECL]]#0 : !fir.ref<i32>
 !CHECK:     %[[X_UPDATE_VAL:.*]] = arith.addi %[[ARG_X]], %[[Y_VAL]] : i32
diff --git a/flang/test/Lower/OpenMP/atomic-update-hlfir.f90 b/flang/test/Lower/OpenMP/atomic-update-hlfir.f90
index f00ed495ae6f89f..329009ab8ef8e9b 100644
--- a/flang/test/Lower/OpenMP/atomic-update-hlfir.f90
+++ b/flang/test/Lower/OpenMP/atomic-update-hlfir.f90
@@ -14,9 +14,9 @@ subroutine sb
 !CHECK:   %[[X_DECL:.*]]:2 = hlfir.declare %[[X_REF]] {uniq_name = "_QFsbEx"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
 !CHECK:   %[[Y_REF:.*]] = fir.alloca i32 {bindc_name = "y", uniq_name = "_QFsbEy"}
 !CHECK:   %[[Y_DECL:.*]]:2 = hlfir.declare %[[Y_REF]] {uniq_name = "_QFsbEy"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
-!CHECK:   omp.atomic.update   %[[X_DECL]]#0 : !fir.ref<i32> {
+!CHECK:   %[[Y_VAL:.*]] = fir.load %[[Y_DECL]]#0 : !fir.ref<i32>
+!CHECK:   omp.atomic.update   %[[X_DECL]]#1 : !fir.ref<i32> {
 !CHECK:   ^bb0(%[[ARG_X:.*]]: i32):
-!CHECK:     %[[Y_VAL:.*]] = fir.load %[[Y_DECL]]#0 : !fir.ref<i32>
 !CHECK:     %[[X_UPDATE_VAL:.*]] = arith.addi %[[ARG_X]], %[[Y_VAL]] : i32
 !CHECK:     omp.yield(%[[X_UPDATE_VAL]] : i32)
 !CHECK:   }
diff --git a/flang/test/Lower/OpenMP/common-atomic-lowering.f90 b/flang/test/Lower/OpenMP/common-atomic-lowering.f90
new file mode 100644
index 000000000000000..076091b44b33a71
--- /dev/null
+++ b/flang/test/Lower/OpenMP/common-atomic-lowering.f90
@@ -0,0 +1,74 @@
+!RUN: %flang_fc1 -emit-hlfir -fopenmp %s -o - | FileCheck %s
+
+!CHECK: func.func @_QQmain() attributes {fir.bindc_name = "sample"} {
+!CHECK: %[[val_0:.*]] = fir.alloca i32 {bindc_name = "a", uniq_name = "_QFEa"}
+!CHECK: %[[val_1:.*]]:2 = hlfir.declare %[[val_0]] {uniq_name = "_QFEa"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
+!CHECK: %[[val_2:.*]] = fir.alloca i32 {bindc_name = "b", uniq_name = "_QFEb"}
+!CHECK: %[[val_3:.*]]:2 = hlfir.declare %[[val_2]] {uniq_name = "_QFEb"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
+!CHECK: %[[val_4:.*]] = fir.alloca i32 {bindc_name = "x", uniq_name = "_QFEx"}
+!CHECK: %[[val_5:.*]]:2 = hlfir.declare %[[val_4]] {uniq_name = "_QFEx"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
+!CHECK: %[[val_c5:.*]] = arith.constant 5 : index
+!CHECK: %[[val_6:.*]] = fir.alloca !fir.array<5xi32> {bindc_name = "y", uniq_name = "_QFEy"}
+!CHECK: %[[val_7:.*]] = fir.shape %[[val_c5]] : (index) -> !fir.shape<1>
+!CHECK: %[[val_8:.*]]:2 = hlfir.declare %[[val_6]](%[[val_7]]) {uniq_name = "_QFEy"} : (!fir.ref<!fir.array<5xi32>>, !fir.shape<1>) -> (!fir.ref<!fir.array<5xi32>>, !fir.ref<!fir.array<5xi32>>)
+!CHECK: %[[val_c2:.*]] = arith.constant 2 : index
+!CHECK: %[[val_9:.*]] = hlfir.designate %[[val_8]]#0 (%[[val_c2]])  : (!fir.ref<!fir.array<5xi32>>, index) -> !fir.ref<i32>
+!CHECK: %[[val_c8:.*]] = arith.constant 8 : i32
+!CHECK: %[[val_10:.*]] = fir.load %[[val_5]]#0 : !fir.ref<i32>
+!CHECK: %[[val_11:.*]] = arith.addi %[[val_c8]], %[[val_10]] : i32
+!CHECK: %[[val_12:.*]] = hlfir.no_reassoc %[[val_11]] : i32
+!CHECK: omp.atomic.update %[[val_9]] : !fir.ref<i32> {
+!CHECK:   ^bb0(%[[ARG:.*]]: i32):
+!CHECK:     %[[val_18:.*]] = arith.muli %[[val_12]], %[[ARG]] : i32
+!CHECK:     omp.yield(%[[val_18]] : i32)
+!CHECK: }
+!CHECK: %[[val_c2_0:.*]] = arith.constant 2 : index
+!CHECK: %[[val_13:.*]] = hlfir.designate %[[val_8]]#0 (%[[val_c2_0]])  : (!fir.ref<!fir.array<5xi32>>, index) -> !fir.ref<i32>
+!CHECK: %[[val_c8_1:.*]] = arith.constant 8 : i32
+!CHECK: omp.atomic.update %[[val_13:.*]] : !fir.ref<i32> {
+!CHECK:   ^bb0(%[[ARG:.*]]: i32):
+!CHECK:     %[[val_18:.*]] = arith.divsi %[[ARG]], %[[val_c8_1]] : i32
+!CHECK:     omp.yield(%[[val_18]] : i32)
+!CHECK: }
+!CHECK: %[[val_c8_2:.*]] = arith.constant 8 : i32
+!CHECK: %[[val_c4:.*]] = arith.constant 4 : index
+!CHECK: %[[val_14:.*]] = hlfir.designate %[[val_8]]#0 (%[[val_c4]])  : (!fir.ref<!fir.array<5xi32>>, index) -> !fir.ref<i32>
+!CHECK: %[[val_15:.*]] = fir.load %[[val_14]] : !fir.ref<i32>
+!CHECK: %[[val_16:.*]] = arith.addi %[[val_c8_2]], %[[val_15]] : i32
+!CHECK: %[[val_17:.*]] = hlfir.no_reassoc %[[val_16]] : i32
+!CHECK: omp.atomic.update %[[val_5]]#1 : !fir.ref<i32> {
+!CHECK:   ^bb0(%[[ARG:.*]]: i32):
+!CHECK:      %[[val_18:.*]] = arith.addi %[[val_17]], %[[ARG]] : i32
+!CHECK:      omp.yield(%[[val_18]] : i32)
+!CHECK: }
+!CHECK: %[[val_c8_3:.*]] = arith.constant 8 : i32
+!CHECK: omp.atomic.update %[[val_5]]#1 : !fir.ref<i32> {
+!CHECK:   ^bb0(%[[ARG]]: i32):
+!CHECK:     %[[val_18:.*]] = arith.subi %[[val_c8_3]], %[[ARG]] : i32
+!CHECK:     omp.yield(%[[val_18]] : i32)
+!CHECK:   }
+!CHECK: return
+!CHECK: }
+program sample
+
+  integer :: x
+  integer, dimension(5) :: y
+  integer :: a, b
+
+  !$omp atomic update
+    y(2) =  (8 + x) * y(2)
+  !$omp end atomic
+
+  !$omp atomic update
+    y(2) =  y(2) / 8
+  !$omp end atomic
+
+  !$omp atomic update
+    x =  (8 + y(4)) + x
+  !$omp end atomic
+
+  !$omp atomic update
+    x =  8 - x
+  !$omp end atomic
+
+end program sample

>From 3e94f5e57cfef1414a09341ae31c0ddb87b16f60 Mon Sep 17 00:00:00 2001
From: Jean Perier <jperier at nvidia.com>
Date: Tue, 24 Oct 2023 01:18:31 -0700
Subject: [PATCH 2/2] Generate non atomic expr before atomic.capture and fix
 tests

---
 flang/lib/Lower/DirectivesCommon.h            | 13 ++++++++---
 .../test/Lower/OpenACC/acc-atomic-capture.f90 | 10 ++++-----
 .../Lower/OpenACC/acc-atomic-update-hlfir.f90 |  2 +-
 .../test/Lower/OpenACC/acc-atomic-update.f90  | 10 ++++-----
 .../test/Lower/OpenMP/FIR/atomic-capture.f90  | 10 ++++-----
 flang/test/Lower/OpenMP/FIR/atomic-update.f90 | 22 +++++++++----------
 6 files changed, 37 insertions(+), 30 deletions(-)

diff --git a/flang/lib/Lower/DirectivesCommon.h b/flang/lib/Lower/DirectivesCommon.h
index 68109315edba40c..91801aad9c074a4 100644
--- a/flang/lib/Lower/DirectivesCommon.h
+++ b/flang/lib/Lower/DirectivesCommon.h
@@ -200,7 +200,8 @@ static inline void genOmpAccAtomicUpdateStatement(
     mlir::Type varType, const Fortran::parser::Variable &assignmentStmtVariable,
     const Fortran::parser::Expr &assignmentStmtExpr,
     [[maybe_unused]] const AtomicListT *leftHandClauseList,
-    [[maybe_unused]] const AtomicListT *rightHandClauseList) {
+    [[maybe_unused]] const AtomicListT *rightHandClauseList,
+    mlir::Operation *atomicCaptureOp = nullptr) {
   // Generate `omp.atomic.update` operation for atomic assignment statements
   fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
   mlir::Location currentLocation = converter.getCurrentLocation();
@@ -240,9 +241,15 @@ static inline void genOmpAccAtomicUpdateStatement(
       assignmentStmtExpr.u);
   StatementContext nonAtomicStmtCtx;
   if (nonAtomicSubExpr) {
+    // Generate non atomic part before all the atomic operations.
+    auto insertionPoint = firOpBuilder.saveInsertionPoint();
+    if (atomicCaptureOp)
+      firOpBuilder.setInsertionPoint(atomicCaptureOp);
     mlir::Value nonAtomicVal = fir::getBase(converter.genExprValue(
         currentLocation, *nonAtomicSubExpr, nonAtomicStmtCtx));
     exprValueOverrides.try_emplace(nonAtomicSubExpr, nonAtomicVal);
+    if (atomicCaptureOp)
+      firOpBuilder.restoreInsertionPoint(insertionPoint);
   }
 
   mlir::Operation *atomicUpdateOp = nullptr;
@@ -500,7 +507,7 @@ void genOmpAccAtomicCapture(Fortran::lower::AbstractConverter &converter,
       genOmpAccAtomicUpdateStatement<AtomicListT>(
           converter, stmt1RHSArg, stmt2VarType, stmt2Var, stmt2Expr,
           /*leftHandClauseList=*/nullptr,
-          /*rightHandClauseList=*/nullptr);
+          /*rightHandClauseList=*/nullptr, atomicCaptureOp);
     } else {
       // Atomic capture construct is of the form [capture-stmt, write-stmt]
       const Fortran::semantics::SomeExpr &fromExpr =
@@ -529,7 +536,7 @@ void genOmpAccAtomicCapture(Fortran::lower::AbstractConverter &converter,
     genOmpAccAtomicUpdateStatement<AtomicListT>(
         converter, stmt1LHSArg, stmt1VarType, stmt1Var, stmt1Expr,
         /*leftHandClauseList=*/nullptr,
-        /*rightHandClauseList=*/nullptr);
+        /*rightHandClauseList=*/nullptr, atomicCaptureOp);
   }
   firOpBuilder.setInsertionPointToEnd(&block);
   if constexpr (std::is_same<AtomicListT,
diff --git a/flang/test/Lower/OpenACC/acc-atomic-capture.f90 b/flang/test/Lower/OpenACC/acc-atomic-capture.f90
index 1a5fe8d57c1533a..382991cf7221ba7 100644
--- a/flang/test/Lower/OpenACC/acc-atomic-capture.f90
+++ b/flang/test/Lower/OpenACC/acc-atomic-capture.f90
@@ -7,11 +7,11 @@ program acc_atomic_capture_test
 
 !CHECK: %[[X:.*]] = fir.alloca i32 {bindc_name = "x", uniq_name = "_QFEx"}
 !CHECK: %[[Y:.*]] = fir.alloca i32 {bindc_name = "y", uniq_name = "_QFEy"}
+!CHECK: %[[temp:.*]] = fir.load %[[X]] : !fir.ref<i32>
 !CHECK: acc.atomic.capture {
 !CHECK: acc.atomic.read %[[X]] = %[[Y]] : !fir.ref<i32>
 !CHECK: acc.atomic.update %[[Y]] : !fir.ref<i32> {
 !CHECK: ^bb0(%[[ARG:.*]]: i32):
-!CHECK: %[[temp:.*]] = fir.load %[[X]] : !fir.ref<i32>
 !CHECK: %[[result:.*]] = arith.addi %[[temp]], %[[ARG]] : i32
 !CHECK: acc.yield %[[result]] : i32
 !CHECK: }
@@ -23,10 +23,10 @@ program acc_atomic_capture_test
     !$acc end atomic
 
 
+!CHECK: %[[temp:.*]] = fir.load %[[X]] : !fir.ref<i32>
 !CHECK: acc.atomic.capture {
 !CHECK: acc.atomic.update %[[Y]] : !fir.ref<i32> {
 !CHECK: ^bb0(%[[ARG:.*]]: i32):
-!CHECK: %[[temp:.*]] = fir.load %[[X]] : !fir.ref<i32>
 !CHECK: %[[result:.*]] = arith.muli %[[temp]], %[[ARG]] : i32
 !CHECK: acc.yield %[[result]] : i32
 !CHECK: }
@@ -76,12 +76,12 @@ subroutine pointers_in_atomic_capture()
 !CHECK: %[[loaded_A_addr:.*]] = fir.box_addr %[[loaded_A]] : (!fir.box<!fir.ptr<i32>>) -> !fir.ptr<i32>
 !CHECK: %[[loaded_B:.*]] = fir.load %[[B]] : !fir.ref<!fir.box<!fir.ptr<i32>>>
 !CHECK: %[[loaded_B_addr:.*]] = fir.box_addr %[[loaded_B]] : (!fir.box<!fir.ptr<i32>>) -> !fir.ptr<i32>
-!CHECK: acc.atomic.capture   {
-!CHECK: acc.atomic.update %[[loaded_A_addr]] : !fir.ptr<i32> {
-!CHECK: ^bb0(%[[ARG:.*]]: i32):
 !CHECK: %[[PRIVATE_LOADED_B:.*]] = fir.load %[[B]] : !fir.ref<!fir.box<!fir.ptr<i32>>>
 !CHECK: %[[PRIVATE_LOADED_B_addr:.*]] = fir.box_addr %[[PRIVATE_LOADED_B]] : (!fir.box<!fir.ptr<i32>>) -> !fir.ptr<i32>
 !CHECK: %[[loaded_value:.*]] = fir.load %[[PRIVATE_LOADED_B_addr]] : !fir.ptr<i32>
+!CHECK: acc.atomic.capture   {
+!CHECK: acc.atomic.update %[[loaded_A_addr]] : !fir.ptr<i32> {
+!CHECK: ^bb0(%[[ARG:.*]]: i32):
 !CHECK: %[[result:.*]] = arith.addi %[[ARG]], %[[loaded_value]] : i32
 !CHECK: acc.yield %[[result]] : i32
 !CHECK: }
diff --git a/flang/test/Lower/OpenACC/acc-atomic-update-hlfir.f90 b/flang/test/Lower/OpenACC/acc-atomic-update-hlfir.f90
index 675373f5d03d81f..b2a993ddd825105 100644
--- a/flang/test/Lower/OpenACC/acc-atomic-update-hlfir.f90
+++ b/flang/test/Lower/OpenACC/acc-atomic-update-hlfir.f90
@@ -14,9 +14,9 @@ subroutine sb
 !CHECK:   %[[X_DECL:.*]]:2 = hlfir.declare %[[X_REF]] {uniq_name = "_QFsbEx"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
 !CHECK:   %[[Y_REF:.*]] = fir.alloca i32 {bindc_name = "y", uniq_name = "_QFsbEy"}
 !CHECK:   %[[Y_DECL:.*]]:2 = hlfir.declare %[[Y_REF]] {uniq_name = "_QFsbEy"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
+!CHECK:   %[[Y_VAL:.*]] = fir.load %[[Y_DECL]]#0 : !fir.ref<i32>
 !CHECK:   acc.atomic.update   %[[X_DECL]]#1 : !fir.ref<i32> {
 !CHECK:   ^bb0(%[[ARG_X:.*]]: i32):
-!CHECK:     %[[Y_VAL:.*]] = fir.load %[[Y_DECL]]#0 : !fir.ref<i32>
 !CHECK:     %[[X_UPDATE_VAL:.*]] = arith.addi %[[ARG_X]], %[[Y_VAL]] : i32
 !CHECK:     acc.yield %[[X_UPDATE_VAL]] : i32
 !CHECK:   }
diff --git a/flang/test/Lower/OpenACC/acc-atomic-update.f90 b/flang/test/Lower/OpenACC/acc-atomic-update.f90
index 546d012982e23ba..96e1d5e58e6f482 100644
--- a/flang/test/Lower/OpenACC/acc-atomic-update.f90
+++ b/flang/test/Lower/OpenACC/acc-atomic-update.f90
@@ -31,25 +31,25 @@ program acc_atomic_update_test
 !CHECK: %{{.*}} = fir.convert %[[D_ADDR]] : (!fir.ref<i32>) -> !fir.ptr<i32>
 !CHECK: fir.store {{.*}} to %[[B_ADDR]] : !fir.ref<!fir.ptr<i32>>
 !CHECK: %[[LOADED_A:.*]] = fir.load %[[A_ADDR]] : !fir.ref<!fir.ptr<i32>>
+!CHECK: %[[LOADED_B:.*]] = fir.load %[[B_ADDR]] : !fir.ref<!fir.ptr<i32>>
+!CHECK: %{{.*}} = fir.load %[[LOADED_B]] : !fir.ptr<i32>
 !CHECK:  acc.atomic.update   %[[LOADED_A]] : !fir.ptr<i32> {
 !CHECK:  ^bb0(%[[ARG:.*]]: i32):
-!CHECK:    %[[LOADED_B:.*]] = fir.load %[[B_ADDR]] : !fir.ref<!fir.ptr<i32>>
-!CHECK:    %{{.*}} = fir.load %[[LOADED_B]] : !fir.ptr<i32>
 !CHECK:    %[[RESULT:.*]] = arith.addi %[[ARG]], %{{.*}} : i32
 !CHECK:    acc.yield %[[RESULT]] : i32
 !CHECK: }
     !$acc atomic update
         a = a + b 
 
+!CHECK: {{.*}} = arith.constant 1 : i32
 !CHECK: acc.atomic.update   %[[Y]] : !fir.ref<i32> {
 !CHECK:  ^bb0(%[[ARG:.*]]: i32):
-!CHECK:    {{.*}} = arith.constant 1 : i32
 !CHECK:    %[[RESULT:.*]] = arith.addi %[[ARG]], {{.*}} : i32
 !CHECK:    acc.yield %[[RESULT]] : i32
 !CHECK:  }
+!CHECK:  %[[LOADED_X:.*]] = fir.load %[[X]] : !fir.ref<i32>
 !CHECK:  acc.atomic.update   %[[Z]] : !fir.ref<i32> {
 !CHECK:  ^bb0(%[[ARG:.*]]: i32):
-!CHECK:    %[[LOADED_X:.*]] = fir.load %[[X]] : !fir.ref<i32>
 !CHECK:    %[[RESULT:.*]] = arith.muli %[[LOADED_X]], %[[ARG]] : i32
 !CHECK:    acc.yield %[[RESULT]] : i32
 !CHECK:  }
@@ -58,10 +58,10 @@ program acc_atomic_update_test
     !$acc atomic update
         z = x * z 
 
+!CHECK:  %[[C1_VAL:.*]] = arith.constant 1 : i32
 !CHECK:  acc.atomic.update   %[[I1]] : !fir.ref<i8> {
 !CHECK:  ^bb0(%[[VAL:.*]]: i8):
 !CHECK:    %[[CVT_VAL:.*]] = fir.convert %[[VAL]] : (i8) -> i32
-!CHECK:    %[[C1_VAL:.*]] = arith.constant 1 : i32
 !CHECK:    %[[ADD_VAL:.*]] = arith.addi %[[CVT_VAL]], %[[C1_VAL]] : i32
 !CHECK:    %[[UPDATED_VAL:.*]] = fir.convert %[[ADD_VAL]] : (i32) -> i8
 !CHECK:    acc.yield %[[UPDATED_VAL]] : i8
diff --git a/flang/test/Lower/OpenMP/FIR/atomic-capture.f90 b/flang/test/Lower/OpenMP/FIR/atomic-capture.f90
index f48a5efaf4354e7..a8b04a2e90cd46b 100644
--- a/flang/test/Lower/OpenMP/FIR/atomic-capture.f90
+++ b/flang/test/Lower/OpenMP/FIR/atomic-capture.f90
@@ -8,11 +8,11 @@ program OmpAtomicCapture
 
 !CHECK: %[[X:.*]] = fir.alloca i32 {bindc_name = "x", uniq_name = "_QFEx"}
 !CHECK: %[[Y:.*]] = fir.alloca i32 {bindc_name = "y", uniq_name = "_QFEy"}
+!CHECK: %[[temp:.*]] = fir.load %[[X]] : !fir.ref<i32>
 !CHECK: omp.atomic.capture   memory_order(release) {
 !CHECK: omp.atomic.read %[[X]] = %[[Y]] : !fir.ref<i32>
 !CHECK: omp.atomic.update %[[Y]] : !fir.ref<i32> {
 !CHECK: ^bb0(%[[ARG:.*]]: i32):
-!CHECK: %[[temp:.*]] = fir.load %[[X]] : !fir.ref<i32>
 !CHECK: %[[result:.*]] = arith.addi %[[temp]], %[[ARG]] : i32
 !CHECK: omp.yield(%[[result]] : i32)
 !CHECK: }
@@ -24,10 +24,10 @@ program OmpAtomicCapture
     !$omp end atomic
 
 
+!CHECK: %[[temp:.*]] = fir.load %[[X]] : !fir.ref<i32>
 !CHECK: omp.atomic.capture   hint(uncontended) {
 !CHECK: omp.atomic.update %[[Y]] : !fir.ref<i32> {
 !CHECK: ^bb0(%[[ARG:.*]]: i32):
-!CHECK: %[[temp:.*]] = fir.load %[[X]] : !fir.ref<i32>
 !CHECK: %[[result:.*]] = arith.muli %[[temp]], %[[ARG]] : i32
 !CHECK: omp.yield(%[[result]] : i32)
 !CHECK: }
@@ -94,12 +94,12 @@ subroutine pointers_in_atomic_capture()
 !CHECK: %[[loaded_A_addr:.*]] = fir.box_addr %[[loaded_A]] : (!fir.box<!fir.ptr<i32>>) -> !fir.ptr<i32>
 !CHECK: %[[loaded_B:.*]] = fir.load %[[B]] : !fir.ref<!fir.box<!fir.ptr<i32>>>
 !CHECK: %[[loaded_B_addr:.*]] = fir.box_addr %[[loaded_B]] : (!fir.box<!fir.ptr<i32>>) -> !fir.ptr<i32>
-!CHECK: omp.atomic.capture   {
-!CHECK: omp.atomic.update %[[loaded_A_addr]] : !fir.ptr<i32> {
-!CHECK: ^bb0(%[[ARG:.*]]: i32):
 !CHECK: %[[PRIVATE_LOADED_B:.*]] = fir.load %[[B]] : !fir.ref<!fir.box<!fir.ptr<i32>>>
 !CHECK: %[[PRIVATE_LOADED_B_addr:.*]] = fir.box_addr %[[PRIVATE_LOADED_B]] : (!fir.box<!fir.ptr<i32>>) -> !fir.ptr<i32>
 !CHECK: %[[loaded_value:.*]] = fir.load %[[PRIVATE_LOADED_B_addr]] : !fir.ptr<i32>
+!CHECK: omp.atomic.capture   {
+!CHECK: omp.atomic.update %[[loaded_A_addr]] : !fir.ptr<i32> {
+!CHECK: ^bb0(%[[ARG:.*]]: i32):
 !CHECK: %[[result:.*]] = arith.addi %[[ARG]], %[[loaded_value]] : i32
 !CHECK: omp.yield(%[[result]] : i32)
 !CHECK: }
diff --git a/flang/test/Lower/OpenMP/FIR/atomic-update.f90 b/flang/test/Lower/OpenMP/FIR/atomic-update.f90
index d0185d2f3b14dfe..56ced10901ab677 100644
--- a/flang/test/Lower/OpenMP/FIR/atomic-update.f90
+++ b/flang/test/Lower/OpenMP/FIR/atomic-update.f90
@@ -32,25 +32,25 @@ program OmpAtomicUpdate
 !CHECK: %{{.*}} = fir.convert %[[D_ADDR]] : (!fir.ref<i32>) -> !fir.ptr<i32>
 !CHECK: fir.store {{.*}} to %[[B_ADDR]] : !fir.ref<!fir.ptr<i32>>
 !CHECK: %[[LOADED_A:.*]] = fir.load %[[A_ADDR]] : !fir.ref<!fir.ptr<i32>>
+!CHECK: %[[LOADED_B:.*]] = fir.load %[[B_ADDR]] : !fir.ref<!fir.ptr<i32>>
+!CHECK: %{{.*}} = fir.load %[[LOADED_B]] : !fir.ptr<i32>
 !CHECK:  omp.atomic.update   %[[LOADED_A]] : !fir.ptr<i32> {
 !CHECK:  ^bb0(%[[ARG:.*]]: i32):
-!CHECK:    %[[LOADED_B:.*]] = fir.load %[[B_ADDR]] : !fir.ref<!fir.ptr<i32>>
-!CHECK:    %{{.*}} = fir.load %[[LOADED_B]] : !fir.ptr<i32>
 !CHECK:    %[[RESULT:.*]] = arith.addi %[[ARG]], %{{.*}} : i32
 !CHECK:    omp.yield(%[[RESULT]] : i32)
 !CHECK: }
     !$omp atomic update
         a = a + b 
 
+!CHECK: {{.*}} = arith.constant 1 : i32
 !CHECK: omp.atomic.update   %[[Y]] : !fir.ref<i32> {
 !CHECK:  ^bb0(%[[ARG:.*]]: i32):
-!CHECK:    {{.*}} = arith.constant 1 : i32
 !CHECK:    %[[RESULT:.*]] = arith.addi %[[ARG]], {{.*}} : i32
 !CHECK:    omp.yield(%[[RESULT]] : i32)
 !CHECK:  }
+!CHECK:  %[[LOADED_X:.*]] = fir.load %[[X]] : !fir.ref<i32>
 !CHECK:  omp.atomic.update   %[[Z]] : !fir.ref<i32> {
 !CHECK:  ^bb0(%[[ARG:.*]]: i32):
-!CHECK:    %[[LOADED_X:.*]] = fir.load %[[X]] : !fir.ref<i32>
 !CHECK:    %[[RESULT:.*]] = arith.muli %[[LOADED_X]], %[[ARG]] : i32
 !CHECK:    omp.yield(%[[RESULT]] : i32)
 !CHECK:  }
@@ -59,9 +59,9 @@ program OmpAtomicUpdate
     !$omp atomic update
         z = x * z 
 
+!CHECK:  %{{.*}} = arith.constant 1 : i32
 !CHECK:  omp.atomic.update   memory_order(relaxed) hint(uncontended) %[[X]] : !fir.ref<i32> {
 !CHECK:  ^bb0(%[[ARG:.*]]: i32):
-!CHECK:    %{{.*}} = arith.constant 1 : i32
 !CHECK:    %[[RESULT:.*]] = arith.subi %[[ARG]], {{.*}} : i32
 !CHECK:    omp.yield(%[[RESULT]] : i32)
 !CHECK:  }
@@ -75,9 +75,9 @@ program OmpAtomicUpdate
 !CHECK:    %[[RESULT:.*]] = arith.select %{{.*}}, %{{.*}}, %[[LOADED_Z]] : i32
 !CHECK:    omp.yield(%[[RESULT]] : i32)
 !CHECK:  }
+!CHECK:  %[[LOADED_X:.*]] = fir.load %[[X]] : !fir.ref<i32>
 !CHECK:  omp.atomic.update   memory_order(relaxed) hint(contended) %[[Z]] : !fir.ref<i32> {
 !CHECK:  ^bb0(%[[ARG:.*]]: i32):
-!CHECK:    %[[LOADED_X:.*]] = fir.load %[[X]] : !fir.ref<i32>
 !CHECK:    %[[RESULT:.*]] = arith.addi %[[ARG]], %[[LOADED_X]] : i32
 !CHECK:    omp.yield(%[[RESULT]] : i32)
 !CHECK:  }
@@ -88,15 +88,15 @@ program OmpAtomicUpdate
     !$omp atomic relaxed hint(omp_sync_hint_contended)
         z = z + x
 
+!CHECK:  %{{.*}} = arith.constant 10 : i32
 !CHECK:  omp.atomic.update   memory_order(release) hint(contended) %[[Z]] : !fir.ref<i32> {
 !CHECK:  ^bb0(%[[ARG:.*]]: i32):
-!CHECK:    %{{.*}} = arith.constant 10 : i32
 !CHECK:   %[[RESULT:.*]] = arith.muli {{.*}}, %[[ARG]] : i32
 !CHECK:    omp.yield(%[[RESULT]] : i32)
 !CHECK:  }
+!CHECK:  %[[LOADED_Z:.*]] = fir.load %[[Z]] : !fir.ref<i32>
 !CHECK:  omp.atomic.update   memory_order(release) hint(speculative) %[[X]] : !fir.ref<i32> {
 !CHECK:  ^bb0(%[[ARG:.*]]: i32):
-!CHECK:    %[[LOADED_Z:.*]] = fir.load %[[Z]] : !fir.ref<i32>
 !CHECK:    %[[RESULT:.*]] = arith.divsi %[[ARG]], %[[LOADED_Z]] : i32
 !CHECK:    omp.yield(%[[RESULT]] : i32)
 !CHECK:  }
@@ -106,15 +106,15 @@ program OmpAtomicUpdate
     !$omp atomic hint(omp_lock_hint_speculative) update release
         x = x / z
 
+!CHECK:  %{{.*}} = arith.constant 10 : i32
 !CHECK:  omp.atomic.update   memory_order(seq_cst) hint(nonspeculative) %[[Y]] : !fir.ref<i32> {
 !CHECK:  ^bb0(%[[ARG:.*]]: i32):
-!CHECK:    %{{.*}} = arith.constant 10 : i32
 !CHECK:    %[[RESULT:.*]] = arith.addi %{{.*}}, %[[ARG]] : i32
 !CHECK:   omp.yield(%[[RESULT]] : i32)
 !CHECK:  }
+!CHECK:  %[[LOADED_Y:.*]] = fir.load %[[Y]] : !fir.ref<i32>
 !CHECK:  omp.atomic.update   memory_order(seq_cst) %[[Z]] : !fir.ref<i32> {
 !CHECK:  ^bb0(%[[ARG:.*]]: i32):
-!CHECK:    %[[LOADED_Y:.*]] = fir.load %[[Y]] : !fir.ref<i32>
 !CHECK:    %[[RESULT:.*]] = arith.addi %[[LOADED_Y]], %[[ARG]] : i32
 !CHECK:    omp.yield(%[[RESULT]] : i32)
 !CHECK:  }
@@ -123,10 +123,10 @@ program OmpAtomicUpdate
     !$omp atomic seq_cst update
         z = y + z
 
+!CHECK:  %[[C1_VAL:.*]] = arith.constant 1 : i32
 !CHECK:  omp.atomic.update   %[[I1]] : !fir.ref<i8> {
 !CHECK:  ^bb0(%[[VAL:.*]]: i8):
 !CHECK:    %[[CVT_VAL:.*]] = fir.convert %[[VAL]] : (i8) -> i32
-!CHECK:    %[[C1_VAL:.*]] = arith.constant 1 : i32
 !CHECK:    %[[ADD_VAL:.*]] = arith.addi %[[CVT_VAL]], %[[C1_VAL]] : i32
 !CHECK:    %[[UPDATED_VAL:.*]] = fir.convert %[[ADD_VAL]] : (i32) -> i8
 !CHECK:    omp.yield(%[[UPDATED_VAL]] : i8)



More information about the flang-commits mailing list