[flang-commits] [flang] 9ea5254 - [flang][OpenACC][OpenMP] Separate implementations of ATOMIC constructs (#137517)

via flang-commits flang-commits at lists.llvm.org
Mon Apr 28 13:43:43 PDT 2025


Author: Krzysztof Parzyszek
Date: 2025-04-28T15:43:39-05:00
New Revision: 9ea5254f77ae5d5fe8e81f8e39b5d461cc95e9dc

URL: https://github.com/llvm/llvm-project/commit/9ea5254f77ae5d5fe8e81f8e39b5d461cc95e9dc
DIFF: https://github.com/llvm/llvm-project/commit/9ea5254f77ae5d5fe8e81f8e39b5d461cc95e9dc.diff

LOG: [flang][OpenACC][OpenMP] Separate implementations of ATOMIC constructs (#137517)

The OpenMP implementation of the ATOMIC construct will change in the
near future to accommodate atomic conditional-update and
conditional-update-capture operations. This patch separates the shared
implementations to avoid interfering with OpenACC.

Added: 
    

Modified: 
    flang/include/flang/Lower/DirectivesCommon.h
    flang/lib/Lower/OpenACC.cpp
    flang/lib/Lower/OpenMP/OpenMP.cpp

Removed: 
    


################################################################################
diff  --git a/flang/include/flang/Lower/DirectivesCommon.h b/flang/include/flang/Lower/DirectivesCommon.h
index d1dbaefcd81d0..93ab2e350d035 100644
--- a/flang/include/flang/Lower/DirectivesCommon.h
+++ b/flang/include/flang/Lower/DirectivesCommon.h
@@ -46,520 +46,6 @@
 namespace Fortran {
 namespace lower {
 
-/// Populates \p hint and \p memoryOrder with appropriate clause information
-/// if present on atomic construct.
-static inline void genOmpAtomicHintAndMemoryOrderClauses(
-    Fortran::lower::AbstractConverter &converter,
-    const Fortran::parser::OmpAtomicClauseList &clauseList,
-    mlir::IntegerAttr &hint,
-    mlir::omp::ClauseMemoryOrderKindAttr &memoryOrder) {
-  fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
-  for (const Fortran::parser::OmpAtomicClause &clause : clauseList.v) {
-    common::visit(
-        common::visitors{
-            [&](const parser::OmpMemoryOrderClause &s) {
-              auto kind = common::visit(
-                  common::visitors{
-                      [&](const parser::OmpClause::AcqRel &) {
-                        return mlir::omp::ClauseMemoryOrderKind::Acq_rel;
-                      },
-                      [&](const parser::OmpClause::Acquire &) {
-                        return mlir::omp::ClauseMemoryOrderKind::Acquire;
-                      },
-                      [&](const parser::OmpClause::Relaxed &) {
-                        return mlir::omp::ClauseMemoryOrderKind::Relaxed;
-                      },
-                      [&](const parser::OmpClause::Release &) {
-                        return mlir::omp::ClauseMemoryOrderKind::Release;
-                      },
-                      [&](const parser::OmpClause::SeqCst &) {
-                        return mlir::omp::ClauseMemoryOrderKind::Seq_cst;
-                      },
-                      [&](auto &&) -> mlir::omp::ClauseMemoryOrderKind {
-                        llvm_unreachable("Unexpected clause");
-                      },
-                  },
-                  s.v.u);
-              memoryOrder = mlir::omp::ClauseMemoryOrderKindAttr::get(
-                  firOpBuilder.getContext(), kind);
-            },
-            [&](const parser::OmpHintClause &s) {
-              const auto *expr = Fortran::semantics::GetExpr(s.v);
-              uint64_t hintExprValue = *Fortran::evaluate::ToInt64(*expr);
-              hint = firOpBuilder.getI64IntegerAttr(hintExprValue);
-            },
-            [&](const parser::OmpFailClause &) {},
-        },
-        clause.u);
-  }
-}
-
-template <typename AtomicListT>
-static void processOmpAtomicTODO(mlir::Type elementType,
-                                 [[maybe_unused]] mlir::Location loc) {
-  if (!elementType)
-    return;
-  if constexpr (std::is_same<AtomicListT,
-                             Fortran::parser::OmpAtomicClauseList>()) {
-    assert(fir::isa_trivial(fir::unwrapRefType(elementType)) &&
-           "is supported type for omp atomic");
-  }
-}
-
-/// Used to generate atomic.read operation which is created in existing
-/// location set by builder.
-template <typename AtomicListT>
-static inline void genOmpAccAtomicCaptureStatement(
-    Fortran::lower::AbstractConverter &converter, mlir::Value fromAddress,
-    mlir::Value toAddress,
-    [[maybe_unused]] const AtomicListT *leftHandClauseList,
-    [[maybe_unused]] const AtomicListT *rightHandClauseList,
-    mlir::Type elementType, mlir::Location loc) {
-  // Generate `atomic.read` operation for atomic assigment statements
-  fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
-
-  processOmpAtomicTODO<AtomicListT>(elementType, loc);
-
-  if constexpr (std::is_same<AtomicListT,
-                             Fortran::parser::OmpAtomicClauseList>()) {
-    // If no hint clause is specified, the effect is as if
-    // hint(omp_sync_hint_none) had been specified.
-    mlir::IntegerAttr hint = nullptr;
-
-    mlir::omp::ClauseMemoryOrderKindAttr memoryOrder = nullptr;
-    if (leftHandClauseList)
-      genOmpAtomicHintAndMemoryOrderClauses(converter, *leftHandClauseList,
-                                            hint, memoryOrder);
-    if (rightHandClauseList)
-      genOmpAtomicHintAndMemoryOrderClauses(converter, *rightHandClauseList,
-                                            hint, memoryOrder);
-    firOpBuilder.create<mlir::omp::AtomicReadOp>(
-        loc, fromAddress, toAddress, mlir::TypeAttr::get(elementType), hint,
-        memoryOrder);
-  } else {
-    firOpBuilder.create<mlir::acc::AtomicReadOp>(
-        loc, fromAddress, toAddress, mlir::TypeAttr::get(elementType));
-  }
-}
-
-/// Used to generate atomic.write operation which is created in existing
-/// location set by builder.
-template <typename AtomicListT>
-static inline void genOmpAccAtomicWriteStatement(
-    Fortran::lower::AbstractConverter &converter, mlir::Value lhsAddr,
-    mlir::Value rhsExpr, [[maybe_unused]] const AtomicListT *leftHandClauseList,
-    [[maybe_unused]] const AtomicListT *rightHandClauseList, mlir::Location loc,
-    mlir::Value *evaluatedExprValue = nullptr) {
-  // Generate `atomic.write` operation for atomic assignment statements
-  fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
-
-  mlir::Type varType = fir::unwrapRefType(lhsAddr.getType());
-  // Create a conversion outside the capture block.
-  auto insertionPoint = firOpBuilder.saveInsertionPoint();
-  firOpBuilder.setInsertionPointAfter(rhsExpr.getDefiningOp());
-  rhsExpr = firOpBuilder.createConvert(loc, varType, rhsExpr);
-  firOpBuilder.restoreInsertionPoint(insertionPoint);
-
-  processOmpAtomicTODO<AtomicListT>(varType, loc);
-
-  if constexpr (std::is_same<AtomicListT,
-                             Fortran::parser::OmpAtomicClauseList>()) {
-    // If no hint clause is specified, the effect is as if
-    // hint(omp_sync_hint_none) had been specified.
-    mlir::IntegerAttr hint = nullptr;
-    mlir::omp::ClauseMemoryOrderKindAttr memoryOrder = nullptr;
-    if (leftHandClauseList)
-      genOmpAtomicHintAndMemoryOrderClauses(converter, *leftHandClauseList,
-                                            hint, memoryOrder);
-    if (rightHandClauseList)
-      genOmpAtomicHintAndMemoryOrderClauses(converter, *rightHandClauseList,
-                                            hint, memoryOrder);
-    firOpBuilder.create<mlir::omp::AtomicWriteOp>(loc, lhsAddr, rhsExpr, hint,
-                                                  memoryOrder);
-  } else {
-    firOpBuilder.create<mlir::acc::AtomicWriteOp>(loc, lhsAddr, rhsExpr);
-  }
-}
-
-/// Used to generate atomic.update operation which is created in existing
-/// location set by builder.
-template <typename AtomicListT>
-static inline void genOmpAccAtomicUpdateStatement(
-    Fortran::lower::AbstractConverter &converter, mlir::Value lhsAddr,
-    mlir::Type varType, const Fortran::parser::Variable &assignmentStmtVariable,
-    const Fortran::parser::Expr &assignmentStmtExpr,
-    [[maybe_unused]] const AtomicListT *leftHandClauseList,
-    [[maybe_unused]] const AtomicListT *rightHandClauseList, mlir::Location loc,
-    mlir::Operation *atomicCaptureOp = nullptr) {
-  // Generate `atomic.update` operation for atomic assignment statements
-  fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
-  mlir::Location currentLocation = converter.getCurrentLocation();
-
-  //  Create the omp.atomic.update or acc.atomic.update operation
-  //
-  //  func.func @_QPsb() {
-  //    %0 = fir.alloca i32 {bindc_name = "a", uniq_name = "_QFsbEa"}
-  //    %1 = fir.alloca i32 {bindc_name = "b", uniq_name = "_QFsbEb"}
-  //    %2 = fir.load %1 : !fir.ref<i32>
-  //    omp.atomic.update   %0 : !fir.ref<i32> {
-  //    ^bb0(%arg0: i32):
-  //      %3 = arith.addi %arg0, %2 : i32
-  //      omp.yield(%3 : i32)
-  //    }
-  //    return
-  //  }
-
-  auto getArgExpression =
-      [](std::list<parser::ActualArgSpec>::const_iterator it) {
-        const auto &arg{std::get<parser::ActualArg>((*it).t)};
-        const auto *parserExpr{
-            std::get_if<common::Indirection<parser::Expr>>(&arg.u)};
-        return parserExpr;
-      };
-
-  // Lower any non atomic sub-expression before the atomic operation, and
-  // map its lowered value to the semantic representation.
-  Fortran::lower::ExprToValueMap exprValueOverrides;
-  // Max and min intrinsics can have a list of Args. Hence we need a list
-  // of nonAtomicSubExprs to hoist. Currently, only the load is hoisted.
-  llvm::SmallVector<const Fortran::lower::SomeExpr *> nonAtomicSubExprs;
-  Fortran::common::visit(
-      Fortran::common::visitors{
-          [&](const common::Indirection<parser::FunctionReference> &funcRef)
-              -> void {
-            const auto &args{std::get<std::list<parser::ActualArgSpec>>(
-                funcRef.value().v.t)};
-            std::list<parser::ActualArgSpec>::const_iterator beginIt =
-                args.begin();
-            std::list<parser::ActualArgSpec>::const_iterator endIt = args.end();
-            const auto *exprFirst{getArgExpression(beginIt)};
-            if (exprFirst && exprFirst->value().source ==
-                                 assignmentStmtVariable.GetSource()) {
-              // Add everything except the first
-              beginIt++;
-            } else {
-              // Add everything except the last
-              endIt--;
-            }
-            std::list<parser::ActualArgSpec>::const_iterator it;
-            for (it = beginIt; it != endIt; it++) {
-              const common::Indirection<parser::Expr> *expr =
-                  getArgExpression(it);
-              if (expr)
-                nonAtomicSubExprs.push_back(Fortran::semantics::GetExpr(*expr));
-            }
-          },
-          [&](const auto &op) -> void {
-            using T = std::decay_t<decltype(op)>;
-            if constexpr (std::is_base_of<
-                              Fortran::parser::Expr::IntrinsicBinary,
-                              T>::value) {
-              const auto &exprLeft{std::get<0>(op.t)};
-              const auto &exprRight{std::get<1>(op.t)};
-              if (exprLeft.value().source == assignmentStmtVariable.GetSource())
-                nonAtomicSubExprs.push_back(
-                    Fortran::semantics::GetExpr(exprRight));
-              else
-                nonAtomicSubExprs.push_back(
-                    Fortran::semantics::GetExpr(exprLeft));
-            }
-          },
-      },
-      assignmentStmtExpr.u);
-  StatementContext nonAtomicStmtCtx;
-  if (!nonAtomicSubExprs.empty()) {
-    // Generate non atomic part before all the atomic operations.
-    auto insertionPoint = firOpBuilder.saveInsertionPoint();
-    if (atomicCaptureOp)
-      firOpBuilder.setInsertionPoint(atomicCaptureOp);
-    mlir::Value nonAtomicVal;
-    for (auto *nonAtomicSubExpr : nonAtomicSubExprs) {
-      nonAtomicVal = fir::getBase(converter.genExprValue(
-          currentLocation, *nonAtomicSubExpr, nonAtomicStmtCtx));
-      exprValueOverrides.try_emplace(nonAtomicSubExpr, nonAtomicVal);
-    }
-    if (atomicCaptureOp)
-      firOpBuilder.restoreInsertionPoint(insertionPoint);
-  }
-
-  mlir::Operation *atomicUpdateOp = nullptr;
-  if constexpr (std::is_same<AtomicListT,
-                             Fortran::parser::OmpAtomicClauseList>()) {
-    // If no hint clause is specified, the effect is as if
-    // hint(omp_sync_hint_none) had been specified.
-    mlir::IntegerAttr hint = nullptr;
-    mlir::omp::ClauseMemoryOrderKindAttr memoryOrder = nullptr;
-    if (leftHandClauseList)
-      genOmpAtomicHintAndMemoryOrderClauses(converter, *leftHandClauseList,
-                                            hint, memoryOrder);
-    if (rightHandClauseList)
-      genOmpAtomicHintAndMemoryOrderClauses(converter, *rightHandClauseList,
-                                            hint, memoryOrder);
-    atomicUpdateOp = firOpBuilder.create<mlir::omp::AtomicUpdateOp>(
-        currentLocation, lhsAddr, hint, memoryOrder);
-  } else {
-    atomicUpdateOp = firOpBuilder.create<mlir::acc::AtomicUpdateOp>(
-        currentLocation, lhsAddr);
-  }
-
-  processOmpAtomicTODO<AtomicListT>(varType, loc);
-
-  llvm::SmallVector<mlir::Type> varTys = {varType};
-  llvm::SmallVector<mlir::Location> locs = {currentLocation};
-  firOpBuilder.createBlock(&atomicUpdateOp->getRegion(0), {}, varTys, locs);
-  mlir::Value val =
-      fir::getBase(atomicUpdateOp->getRegion(0).front().getArgument(0));
-
-  exprValueOverrides.try_emplace(
-      Fortran::semantics::GetExpr(assignmentStmtVariable), val);
-  {
-    // statement context inside the atomic block.
-    converter.overrideExprValues(&exprValueOverrides);
-    Fortran::lower::StatementContext atomicStmtCtx;
-    mlir::Value rhsExpr = fir::getBase(converter.genExprValue(
-        *Fortran::semantics::GetExpr(assignmentStmtExpr), atomicStmtCtx));
-    mlir::Value convertResult =
-        firOpBuilder.createConvert(currentLocation, varType, rhsExpr);
-    if constexpr (std::is_same<AtomicListT,
-                               Fortran::parser::OmpAtomicClauseList>()) {
-      firOpBuilder.create<mlir::omp::YieldOp>(currentLocation, convertResult);
-    } else {
-      firOpBuilder.create<mlir::acc::YieldOp>(currentLocation, convertResult);
-    }
-    converter.resetExprOverrides();
-  }
-  firOpBuilder.setInsertionPointAfter(atomicUpdateOp);
-}
-
-/// Processes an atomic construct with write clause.
-template <typename AtomicT, typename AtomicListT>
-void genOmpAccAtomicWrite(Fortran::lower::AbstractConverter &converter,
-                          const AtomicT &atomicWrite, mlir::Location loc) {
-  const AtomicListT *rightHandClauseList = nullptr;
-  const AtomicListT *leftHandClauseList = nullptr;
-  if constexpr (std::is_same<AtomicListT,
-                             Fortran::parser::OmpAtomicClauseList>()) {
-    // Get the address of atomic read operands.
-    rightHandClauseList = &std::get<2>(atomicWrite.t);
-    leftHandClauseList = &std::get<0>(atomicWrite.t);
-  }
-
-  const Fortran::parser::AssignmentStmt &stmt =
-      std::get<Fortran::parser::Statement<Fortran::parser::AssignmentStmt>>(
-          atomicWrite.t)
-          .statement;
-  const Fortran::evaluate::Assignment &assign = *stmt.typedAssignment->v;
-  Fortran::lower::StatementContext stmtCtx;
-  // Get the value and address of atomic write operands.
-  mlir::Value rhsExpr =
-      fir::getBase(converter.genExprValue(assign.rhs, stmtCtx));
-  mlir::Value lhsAddr =
-      fir::getBase(converter.genExprAddr(assign.lhs, stmtCtx));
-  genOmpAccAtomicWriteStatement(converter, lhsAddr, rhsExpr, leftHandClauseList,
-                                rightHandClauseList, loc);
-}
-
-/// Processes an atomic construct with read clause.
-template <typename AtomicT, typename AtomicListT>
-void genOmpAccAtomicRead(Fortran::lower::AbstractConverter &converter,
-                         const AtomicT &atomicRead, mlir::Location loc) {
-  const AtomicListT *rightHandClauseList = nullptr;
-  const AtomicListT *leftHandClauseList = nullptr;
-  if constexpr (std::is_same<AtomicListT,
-                             Fortran::parser::OmpAtomicClauseList>()) {
-    // Get the address of atomic read operands.
-    rightHandClauseList = &std::get<2>(atomicRead.t);
-    leftHandClauseList = &std::get<0>(atomicRead.t);
-  }
-
-  const auto &assignmentStmtExpr = std::get<Fortran::parser::Expr>(
-      std::get<Fortran::parser::Statement<Fortran::parser::AssignmentStmt>>(
-          atomicRead.t)
-          .statement.t);
-  const auto &assignmentStmtVariable = std::get<Fortran::parser::Variable>(
-      std::get<Fortran::parser::Statement<Fortran::parser::AssignmentStmt>>(
-          atomicRead.t)
-          .statement.t);
-
-  Fortran::lower::StatementContext stmtCtx;
-  const Fortran::semantics::SomeExpr &fromExpr =
-      *Fortran::semantics::GetExpr(assignmentStmtExpr);
-  mlir::Type elementType = converter.genType(fromExpr);
-  mlir::Value fromAddress =
-      fir::getBase(converter.genExprAddr(fromExpr, stmtCtx));
-  mlir::Value toAddress = fir::getBase(converter.genExprAddr(
-      *Fortran::semantics::GetExpr(assignmentStmtVariable), stmtCtx));
-  genOmpAccAtomicCaptureStatement(converter, fromAddress, toAddress,
-                                  leftHandClauseList, rightHandClauseList,
-                                  elementType, loc);
-}
-
-/// Processes an atomic construct with update clause.
-template <typename AtomicT, typename AtomicListT>
-void genOmpAccAtomicUpdate(Fortran::lower::AbstractConverter &converter,
-                           const AtomicT &atomicUpdate, mlir::Location loc) {
-  const AtomicListT *rightHandClauseList = nullptr;
-  const AtomicListT *leftHandClauseList = nullptr;
-  if constexpr (std::is_same<AtomicListT,
-                             Fortran::parser::OmpAtomicClauseList>()) {
-    // Get the address of atomic read operands.
-    rightHandClauseList = &std::get<2>(atomicUpdate.t);
-    leftHandClauseList = &std::get<0>(atomicUpdate.t);
-  }
-
-  const auto &assignmentStmtExpr = std::get<Fortran::parser::Expr>(
-      std::get<Fortran::parser::Statement<Fortran::parser::AssignmentStmt>>(
-          atomicUpdate.t)
-          .statement.t);
-  const auto &assignmentStmtVariable = std::get<Fortran::parser::Variable>(
-      std::get<Fortran::parser::Statement<Fortran::parser::AssignmentStmt>>(
-          atomicUpdate.t)
-          .statement.t);
-
-  Fortran::lower::StatementContext stmtCtx;
-  mlir::Value lhsAddr = fir::getBase(converter.genExprAddr(
-      *Fortran::semantics::GetExpr(assignmentStmtVariable), stmtCtx));
-  mlir::Type varType = fir::unwrapRefType(lhsAddr.getType());
-  genOmpAccAtomicUpdateStatement<AtomicListT>(
-      converter, lhsAddr, varType, assignmentStmtVariable, assignmentStmtExpr,
-      leftHandClauseList, rightHandClauseList, loc);
-}
-
-/// Processes an atomic construct with no clause - which implies update clause.
-template <typename AtomicT, typename AtomicListT>
-void genOmpAtomic(Fortran::lower::AbstractConverter &converter,
-                  const AtomicT &atomicConstruct, mlir::Location loc) {
-  const AtomicListT &atomicClauseList =
-      std::get<AtomicListT>(atomicConstruct.t);
-  const auto &assignmentStmtExpr = std::get<Fortran::parser::Expr>(
-      std::get<Fortran::parser::Statement<Fortran::parser::AssignmentStmt>>(
-          atomicConstruct.t)
-          .statement.t);
-  const auto &assignmentStmtVariable = std::get<Fortran::parser::Variable>(
-      std::get<Fortran::parser::Statement<Fortran::parser::AssignmentStmt>>(
-          atomicConstruct.t)
-          .statement.t);
-  Fortran::lower::StatementContext stmtCtx;
-  mlir::Value lhsAddr = fir::getBase(converter.genExprAddr(
-      *Fortran::semantics::GetExpr(assignmentStmtVariable), stmtCtx));
-  mlir::Type varType = fir::unwrapRefType(lhsAddr.getType());
-  // If atomic-clause is not present on the construct, the behaviour is as if
-  // the update clause is specified (for both OpenMP and OpenACC).
-  genOmpAccAtomicUpdateStatement<AtomicListT>(
-      converter, lhsAddr, varType, assignmentStmtVariable, assignmentStmtExpr,
-      &atomicClauseList, nullptr, loc);
-}
-
-/// Processes an atomic construct with capture clause.
-template <typename AtomicT, typename AtomicListT>
-void genOmpAccAtomicCapture(Fortran::lower::AbstractConverter &converter,
-                            const AtomicT &atomicCapture, mlir::Location loc) {
-  fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
-
-  const Fortran::parser::AssignmentStmt &stmt1 =
-      std::get<typename AtomicT::Stmt1>(atomicCapture.t).v.statement;
-  const Fortran::evaluate::Assignment &assign1 = *stmt1.typedAssignment->v;
-  const auto &stmt1Var{std::get<Fortran::parser::Variable>(stmt1.t)};
-  const auto &stmt1Expr{std::get<Fortran::parser::Expr>(stmt1.t)};
-  const Fortran::parser::AssignmentStmt &stmt2 =
-      std::get<typename AtomicT::Stmt2>(atomicCapture.t).v.statement;
-  const Fortran::evaluate::Assignment &assign2 = *stmt2.typedAssignment->v;
-  const auto &stmt2Var{std::get<Fortran::parser::Variable>(stmt2.t)};
-  const auto &stmt2Expr{std::get<Fortran::parser::Expr>(stmt2.t)};
-
-  // Pre-evaluate expressions to be used in the various operations inside
-  // `atomic.capture` since it is not desirable to have anything other than
-  // a `atomic.read`, `atomic.write`, or `atomic.update` operation
-  // inside `atomic.capture`
-  Fortran::lower::StatementContext stmtCtx;
-  // LHS evaluations are common to all combinations of `atomic.capture`
-  mlir::Value stmt1LHSArg =
-      fir::getBase(converter.genExprAddr(assign1.lhs, stmtCtx));
-  mlir::Value stmt2LHSArg =
-      fir::getBase(converter.genExprAddr(assign2.lhs, stmtCtx));
-
-  // Type information used in generation of `atomic.update` operation
-  mlir::Type stmt1VarType =
-      fir::getBase(converter.genExprValue(assign1.lhs, stmtCtx)).getType();
-  mlir::Type stmt2VarType =
-      fir::getBase(converter.genExprValue(assign2.lhs, stmtCtx)).getType();
-
-  mlir::Operation *atomicCaptureOp = nullptr;
-  if constexpr (std::is_same<AtomicListT,
-                             Fortran::parser::OmpAtomicClauseList>()) {
-    mlir::IntegerAttr hint = nullptr;
-    mlir::omp::ClauseMemoryOrderKindAttr memoryOrder = nullptr;
-    const AtomicListT &rightHandClauseList = std::get<2>(atomicCapture.t);
-    const AtomicListT &leftHandClauseList = std::get<0>(atomicCapture.t);
-    genOmpAtomicHintAndMemoryOrderClauses(converter, leftHandClauseList, hint,
-                                          memoryOrder);
-    genOmpAtomicHintAndMemoryOrderClauses(converter, rightHandClauseList, hint,
-                                          memoryOrder);
-    atomicCaptureOp =
-        firOpBuilder.create<mlir::omp::AtomicCaptureOp>(loc, hint, memoryOrder);
-  } else {
-    atomicCaptureOp = firOpBuilder.create<mlir::acc::AtomicCaptureOp>(loc);
-  }
-
-  firOpBuilder.createBlock(&(atomicCaptureOp->getRegion(0)));
-  mlir::Block &block = atomicCaptureOp->getRegion(0).back();
-  firOpBuilder.setInsertionPointToStart(&block);
-  if (Fortran::semantics::checkForSingleVariableOnRHS(stmt1)) {
-    if (Fortran::semantics::checkForSymbolMatch(stmt2)) {
-      // Atomic capture construct is of the form [capture-stmt, update-stmt]
-      const Fortran::semantics::SomeExpr &fromExpr =
-          *Fortran::semantics::GetExpr(stmt1Expr);
-      mlir::Type elementType = converter.genType(fromExpr);
-      genOmpAccAtomicCaptureStatement<AtomicListT>(
-          converter, stmt2LHSArg, stmt1LHSArg,
-          /*leftHandClauseList=*/nullptr,
-          /*rightHandClauseList=*/nullptr, elementType, loc);
-      genOmpAccAtomicUpdateStatement<AtomicListT>(
-          converter, stmt2LHSArg, stmt2VarType, stmt2Var, stmt2Expr,
-          /*leftHandClauseList=*/nullptr,
-          /*rightHandClauseList=*/nullptr, loc, atomicCaptureOp);
-    } else {
-      // Atomic capture construct is of the form [capture-stmt, write-stmt]
-      firOpBuilder.setInsertionPoint(atomicCaptureOp);
-      mlir::Value stmt2RHSArg =
-          fir::getBase(converter.genExprValue(assign2.rhs, stmtCtx));
-      firOpBuilder.setInsertionPointToStart(&block);
-      const Fortran::semantics::SomeExpr &fromExpr =
-          *Fortran::semantics::GetExpr(stmt1Expr);
-      mlir::Type elementType = converter.genType(fromExpr);
-      genOmpAccAtomicCaptureStatement<AtomicListT>(
-          converter, stmt2LHSArg, stmt1LHSArg,
-          /*leftHandClauseList=*/nullptr,
-          /*rightHandClauseList=*/nullptr, elementType, loc);
-      genOmpAccAtomicWriteStatement<AtomicListT>(
-          converter, stmt2LHSArg, stmt2RHSArg,
-          /*leftHandClauseList=*/nullptr,
-          /*rightHandClauseList=*/nullptr, loc);
-    }
-  } else {
-    // Atomic capture construct is of the form [update-stmt, capture-stmt]
-    const Fortran::semantics::SomeExpr &fromExpr =
-        *Fortran::semantics::GetExpr(stmt2Expr);
-    mlir::Type elementType = converter.genType(fromExpr);
-    genOmpAccAtomicUpdateStatement<AtomicListT>(
-        converter, stmt1LHSArg, stmt1VarType, stmt1Var, stmt1Expr,
-        /*leftHandClauseList=*/nullptr,
-        /*rightHandClauseList=*/nullptr, loc, atomicCaptureOp);
-    genOmpAccAtomicCaptureStatement<AtomicListT>(
-        converter, stmt1LHSArg, stmt2LHSArg,
-        /*leftHandClauseList=*/nullptr,
-        /*rightHandClauseList=*/nullptr, elementType, loc);
-  }
-  firOpBuilder.setInsertionPointToEnd(&block);
-  if constexpr (std::is_same<AtomicListT,
-                             Fortran::parser::OmpAtomicClauseList>()) {
-    firOpBuilder.create<mlir::omp::TerminatorOp>(loc);
-  } else {
-    firOpBuilder.create<mlir::acc::TerminatorOp>(loc);
-  }
-  firOpBuilder.setInsertionPointToStart(&block);
-}
-
 /// Create empty blocks for the current region.
 /// These blocks replace blocks parented to an enclosing region.
 template <typename... TerminatorOps>

diff  --git a/flang/lib/Lower/OpenACC.cpp b/flang/lib/Lower/OpenACC.cpp
index 418bf4ee3d15f..e6175ebda40b2 100644
--- a/flang/lib/Lower/OpenACC.cpp
+++ b/flang/lib/Lower/OpenACC.cpp
@@ -375,6 +375,310 @@ getSymbolFromAccObject(const Fortran::parser::AccObject &accObject) {
   llvm::report_fatal_error("Could not find symbol");
 }
 
+/// Used to generate atomic.read operation which is created in existing
+/// location set by builder.
+static inline void
+genAtomicCaptureStatement(Fortran::lower::AbstractConverter &converter,
+                          mlir::Value fromAddress, mlir::Value toAddress,
+                          mlir::Type elementType, mlir::Location loc) {
+  // Generate `atomic.read` operation for atomic assignment statements
+  fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
+
+  firOpBuilder.create<mlir::acc::AtomicReadOp>(
+      loc, fromAddress, toAddress, mlir::TypeAttr::get(elementType));
+}
+
+/// Used to generate atomic.write operation which is created in existing
+/// location set by builder.
+static inline void
+genAtomicWriteStatement(Fortran::lower::AbstractConverter &converter,
+                        mlir::Value lhsAddr, mlir::Value rhsExpr,
+                        mlir::Location loc,
+                        mlir::Value *evaluatedExprValue = nullptr) {
+  // Generate `atomic.write` operation for atomic assignment statements
+  fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
+
+  mlir::Type varType = fir::unwrapRefType(lhsAddr.getType());
+  // Create a conversion outside the capture block.
+  auto insertionPoint = firOpBuilder.saveInsertionPoint();
+  firOpBuilder.setInsertionPointAfter(rhsExpr.getDefiningOp());
+  rhsExpr = firOpBuilder.createConvert(loc, varType, rhsExpr);
+  firOpBuilder.restoreInsertionPoint(insertionPoint);
+
+  firOpBuilder.create<mlir::acc::AtomicWriteOp>(loc, lhsAddr, rhsExpr);
+}
+
+/// Used to generate atomic.update operation which is created in existing
+/// location set by builder.
+static inline void genAtomicUpdateStatement(
+    Fortran::lower::AbstractConverter &converter, mlir::Value lhsAddr,
+    mlir::Type varType, const Fortran::parser::Variable &assignmentStmtVariable,
+    const Fortran::parser::Expr &assignmentStmtExpr, mlir::Location loc,
+    mlir::Operation *atomicCaptureOp = nullptr) {
+  // Generate `atomic.update` operation for atomic assignment statements
+  fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
+  mlir::Location currentLocation = converter.getCurrentLocation();
+
+  //  Create the acc.atomic.update operation
+  //
+  //  func.func @_QPsb() {
+  //    %0 = fir.alloca i32 {bindc_name = "a", uniq_name = "_QFsbEa"}
+  //    %1 = fir.alloca i32 {bindc_name = "b", uniq_name = "_QFsbEb"}
+  //    %2 = fir.load %1 : !fir.ref<i32>
+  //    acc.atomic.update %0 : !fir.ref<i32> {
+  //    ^bb0(%arg0: i32):
+  //      %3 = arith.addi %arg0, %2 : i32
+  //      acc.yield %3 : i32
+  //    }
+  //    return
+  //  }
+
+  auto getArgExpression =
+      [](std::list<Fortran::parser::ActualArgSpec>::const_iterator it) {
+        const auto &arg{std::get<Fortran::parser::ActualArg>((*it).t)};
+        const auto *parserExpr{
+            std::get_if<Fortran::common::Indirection<Fortran::parser::Expr>>(
+                &arg.u)};
+        return parserExpr;
+      };
+
+  // Lower any non atomic sub-expression before the atomic operation, and
+  // map its lowered value to the semantic representation.
+  Fortran::lower::ExprToValueMap exprValueOverrides;
+  // Max and min intrinsics can have a list of Args. Hence we need a list
+  // of nonAtomicSubExprs to hoist. Currently, only the load is hoisted.
+  llvm::SmallVector<const Fortran::lower::SomeExpr *> nonAtomicSubExprs;
+  Fortran::common::visit(
+      Fortran::common::visitors{
+          [&](const Fortran::common::Indirection<
+              Fortran::parser::FunctionReference> &funcRef) -> void {
+            const auto &args{
+                std::get<std::list<Fortran::parser::ActualArgSpec>>(
+                    funcRef.value().v.t)};
+            std::list<Fortran::parser::ActualArgSpec>::const_iterator beginIt =
+                args.begin();
+            std::list<Fortran::parser::ActualArgSpec>::const_iterator endIt =
+                args.end();
+            const auto *exprFirst{getArgExpression(beginIt)};
+            if (exprFirst && exprFirst->value().source ==
+                                 assignmentStmtVariable.GetSource()) {
+              // Add everything except the first
+              beginIt++;
+            } else {
+              // Add everything except the last
+              endIt--;
+            }
+            std::list<Fortran::parser::ActualArgSpec>::const_iterator it;
+            for (it = beginIt; it != endIt; it++) {
+              const Fortran::common::Indirection<Fortran::parser::Expr> *expr =
+                  getArgExpression(it);
+              if (expr)
+                nonAtomicSubExprs.push_back(Fortran::semantics::GetExpr(*expr));
+            }
+          },
+          [&](const auto &op) -> void {
+            using T = std::decay_t<decltype(op)>;
+            if constexpr (std::is_base_of<
+                              Fortran::parser::Expr::IntrinsicBinary,
+                              T>::value) {
+              const auto &exprLeft{std::get<0>(op.t)};
+              const auto &exprRight{std::get<1>(op.t)};
+              if (exprLeft.value().source == assignmentStmtVariable.GetSource())
+                nonAtomicSubExprs.push_back(
+                    Fortran::semantics::GetExpr(exprRight));
+              else
+                nonAtomicSubExprs.push_back(
+                    Fortran::semantics::GetExpr(exprLeft));
+            }
+          },
+      },
+      assignmentStmtExpr.u);
+  Fortran::lower::StatementContext nonAtomicStmtCtx;
+  if (!nonAtomicSubExprs.empty()) {
+    // Generate non atomic part before all the atomic operations.
+    auto insertionPoint = firOpBuilder.saveInsertionPoint();
+    if (atomicCaptureOp)
+      firOpBuilder.setInsertionPoint(atomicCaptureOp);
+    mlir::Value nonAtomicVal;
+    for (auto *nonAtomicSubExpr : nonAtomicSubExprs) {
+      nonAtomicVal = fir::getBase(converter.genExprValue(
+          currentLocation, *nonAtomicSubExpr, nonAtomicStmtCtx));
+      exprValueOverrides.try_emplace(nonAtomicSubExpr, nonAtomicVal);
+    }
+    if (atomicCaptureOp)
+      firOpBuilder.restoreInsertionPoint(insertionPoint);
+  }
+
+  mlir::Operation *atomicUpdateOp = nullptr;
+  atomicUpdateOp =
+      firOpBuilder.create<mlir::acc::AtomicUpdateOp>(currentLocation, lhsAddr);
+
+  llvm::SmallVector<mlir::Type> varTys = {varType};
+  llvm::SmallVector<mlir::Location> locs = {currentLocation};
+  firOpBuilder.createBlock(&atomicUpdateOp->getRegion(0), {}, varTys, locs);
+  mlir::Value val =
+      fir::getBase(atomicUpdateOp->getRegion(0).front().getArgument(0));
+
+  exprValueOverrides.try_emplace(
+      Fortran::semantics::GetExpr(assignmentStmtVariable), val);
+  {
+    // statement context inside the atomic block.
+    converter.overrideExprValues(&exprValueOverrides);
+    Fortran::lower::StatementContext atomicStmtCtx;
+    mlir::Value rhsExpr = fir::getBase(converter.genExprValue(
+        *Fortran::semantics::GetExpr(assignmentStmtExpr), atomicStmtCtx));
+    mlir::Value convertResult =
+        firOpBuilder.createConvert(currentLocation, varType, rhsExpr);
+    firOpBuilder.create<mlir::acc::YieldOp>(currentLocation, convertResult);
+    converter.resetExprOverrides();
+  }
+  firOpBuilder.setInsertionPointAfter(atomicUpdateOp);
+}
+
+/// Processes an atomic construct with write clause.
+void genAtomicWrite(Fortran::lower::AbstractConverter &converter,
+                    const Fortran::parser::AccAtomicWrite &atomicWrite,
+                    mlir::Location loc) {
+  const Fortran::parser::AssignmentStmt &stmt =
+      std::get<Fortran::parser::Statement<Fortran::parser::AssignmentStmt>>(
+          atomicWrite.t)
+          .statement;
+  const Fortran::evaluate::Assignment &assign = *stmt.typedAssignment->v;
+  Fortran::lower::StatementContext stmtCtx;
+  // Get the value and address of atomic write operands.
+  mlir::Value rhsExpr =
+      fir::getBase(converter.genExprValue(assign.rhs, stmtCtx));
+  mlir::Value lhsAddr =
+      fir::getBase(converter.genExprAddr(assign.lhs, stmtCtx));
+  genAtomicWriteStatement(converter, lhsAddr, rhsExpr, loc);
+}
+
+/// Processes an atomic construct with read clause.
+void genAtomicRead(Fortran::lower::AbstractConverter &converter,
+                   const Fortran::parser::AccAtomicRead &atomicRead,
+                   mlir::Location loc) {
+  const auto &assignmentStmtExpr = std::get<Fortran::parser::Expr>(
+      std::get<Fortran::parser::Statement<Fortran::parser::AssignmentStmt>>(
+          atomicRead.t)
+          .statement.t);
+  const auto &assignmentStmtVariable = std::get<Fortran::parser::Variable>(
+      std::get<Fortran::parser::Statement<Fortran::parser::AssignmentStmt>>(
+          atomicRead.t)
+          .statement.t);
+
+  Fortran::lower::StatementContext stmtCtx;
+  const Fortran::semantics::SomeExpr &fromExpr =
+      *Fortran::semantics::GetExpr(assignmentStmtExpr);
+  mlir::Type elementType = converter.genType(fromExpr);
+  mlir::Value fromAddress =
+      fir::getBase(converter.genExprAddr(fromExpr, stmtCtx));
+  mlir::Value toAddress = fir::getBase(converter.genExprAddr(
+      *Fortran::semantics::GetExpr(assignmentStmtVariable), stmtCtx));
+  genAtomicCaptureStatement(converter, fromAddress, toAddress, elementType,
+                            loc);
+}
+
+/// Processes an atomic construct with update clause.
+void genAtomicUpdate(Fortran::lower::AbstractConverter &converter,
+                     const Fortran::parser::AccAtomicUpdate &atomicUpdate,
+                     mlir::Location loc) {
+  const auto &assignmentStmtExpr = std::get<Fortran::parser::Expr>(
+      std::get<Fortran::parser::Statement<Fortran::parser::AssignmentStmt>>(
+          atomicUpdate.t)
+          .statement.t);
+  const auto &assignmentStmtVariable = std::get<Fortran::parser::Variable>(
+      std::get<Fortran::parser::Statement<Fortran::parser::AssignmentStmt>>(
+          atomicUpdate.t)
+          .statement.t);
+
+  Fortran::lower::StatementContext stmtCtx;
+  mlir::Value lhsAddr = fir::getBase(converter.genExprAddr(
+      *Fortran::semantics::GetExpr(assignmentStmtVariable), stmtCtx));
+  mlir::Type varType = fir::unwrapRefType(lhsAddr.getType());
+  genAtomicUpdateStatement(converter, lhsAddr, varType, assignmentStmtVariable,
+                           assignmentStmtExpr, loc);
+}
+
+/// Processes an atomic construct with capture clause.
+void genAtomicCapture(Fortran::lower::AbstractConverter &converter,
+                      const Fortran::parser::AccAtomicCapture &atomicCapture,
+                      mlir::Location loc) {
+  fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
+
+  const Fortran::parser::AssignmentStmt &stmt1 =
+      std::get<Fortran::parser::AccAtomicCapture::Stmt1>(atomicCapture.t)
+          .v.statement;
+  const Fortran::evaluate::Assignment &assign1 = *stmt1.typedAssignment->v;
+  const auto &stmt1Var{std::get<Fortran::parser::Variable>(stmt1.t)};
+  const auto &stmt1Expr{std::get<Fortran::parser::Expr>(stmt1.t)};
+  const Fortran::parser::AssignmentStmt &stmt2 =
+      std::get<Fortran::parser::AccAtomicCapture::Stmt2>(atomicCapture.t)
+          .v.statement;
+  const Fortran::evaluate::Assignment &assign2 = *stmt2.typedAssignment->v;
+  const auto &stmt2Var{std::get<Fortran::parser::Variable>(stmt2.t)};
+  const auto &stmt2Expr{std::get<Fortran::parser::Expr>(stmt2.t)};
+
+  // Pre-evaluate expressions to be used in the various operations inside
+  // `atomic.capture` since it is not desirable to have anything other than
+  // an `atomic.read`, `atomic.write`, or `atomic.update` operation
+  // inside `atomic.capture`
+  Fortran::lower::StatementContext stmtCtx;
+  // LHS evaluations are common to all combinations of `atomic.capture`
+  mlir::Value stmt1LHSArg =
+      fir::getBase(converter.genExprAddr(assign1.lhs, stmtCtx));
+  mlir::Value stmt2LHSArg =
+      fir::getBase(converter.genExprAddr(assign2.lhs, stmtCtx));
+
+  // Type information used in generation of `atomic.update` operation
+  mlir::Type stmt1VarType =
+      fir::getBase(converter.genExprValue(assign1.lhs, stmtCtx)).getType();
+  mlir::Type stmt2VarType =
+      fir::getBase(converter.genExprValue(assign2.lhs, stmtCtx)).getType();
+
+  mlir::Operation *atomicCaptureOp = nullptr;
+  atomicCaptureOp = firOpBuilder.create<mlir::acc::AtomicCaptureOp>(loc);
+
+  firOpBuilder.createBlock(&(atomicCaptureOp->getRegion(0)));
+  mlir::Block &block = atomicCaptureOp->getRegion(0).back();
+  firOpBuilder.setInsertionPointToStart(&block);
+  if (Fortran::semantics::checkForSingleVariableOnRHS(stmt1)) {
+    if (Fortran::semantics::checkForSymbolMatch(stmt2)) {
+      // Atomic capture construct is of the form [capture-stmt, update-stmt]
+      const Fortran::semantics::SomeExpr &fromExpr =
+          *Fortran::semantics::GetExpr(stmt1Expr);
+      mlir::Type elementType = converter.genType(fromExpr);
+      genAtomicCaptureStatement(converter, stmt2LHSArg, stmt1LHSArg,
+                                elementType, loc);
+      genAtomicUpdateStatement(converter, stmt2LHSArg, stmt2VarType, stmt2Var,
+                               stmt2Expr, loc, atomicCaptureOp);
+    } else {
+      // Atomic capture construct is of the form [capture-stmt, write-stmt]
+      firOpBuilder.setInsertionPoint(atomicCaptureOp);
+      mlir::Value stmt2RHSArg =
+          fir::getBase(converter.genExprValue(assign2.rhs, stmtCtx));
+      firOpBuilder.setInsertionPointToStart(&block);
+      const Fortran::semantics::SomeExpr &fromExpr =
+          *Fortran::semantics::GetExpr(stmt1Expr);
+      mlir::Type elementType = converter.genType(fromExpr);
+      genAtomicCaptureStatement(converter, stmt2LHSArg, stmt1LHSArg,
+                                elementType, loc);
+      genAtomicWriteStatement(converter, stmt2LHSArg, stmt2RHSArg, loc);
+    }
+  } else {
+    // Atomic capture construct is of the form [update-stmt, capture-stmt]
+    const Fortran::semantics::SomeExpr &fromExpr =
+        *Fortran::semantics::GetExpr(stmt2Expr);
+    mlir::Type elementType = converter.genType(fromExpr);
+    genAtomicUpdateStatement(converter, stmt1LHSArg, stmt1VarType, stmt1Var,
+                             stmt1Expr, loc, atomicCaptureOp);
+    genAtomicCaptureStatement(converter, stmt1LHSArg, stmt2LHSArg, elementType,
+                              loc);
+  }
+  firOpBuilder.setInsertionPointToEnd(&block);
+  firOpBuilder.create<mlir::acc::TerminatorOp>(loc);
+  firOpBuilder.setInsertionPointToStart(&block);
+}
+
 template <typename Op>
 static void
 genDataOperandOperations(const Fortran::parser::AccObjectList &objectList,
@@ -4352,24 +4656,16 @@ genACC(Fortran::lower::AbstractConverter &converter,
   Fortran::common::visit(
       Fortran::common::visitors{
           [&](const Fortran::parser::AccAtomicRead &atomicRead) {
-            Fortran::lower::genOmpAccAtomicRead<Fortran::parser::AccAtomicRead,
-                                                void>(converter, atomicRead,
-                                                      loc);
+            genAtomicRead(converter, atomicRead, loc);
           },
           [&](const Fortran::parser::AccAtomicWrite &atomicWrite) {
-            Fortran::lower::genOmpAccAtomicWrite<
-                Fortran::parser::AccAtomicWrite, void>(converter, atomicWrite,
-                                                       loc);
+            genAtomicWrite(converter, atomicWrite, loc);
           },
           [&](const Fortran::parser::AccAtomicUpdate &atomicUpdate) {
-            Fortran::lower::genOmpAccAtomicUpdate<
-                Fortran::parser::AccAtomicUpdate, void>(converter, atomicUpdate,
-                                                        loc);
+            genAtomicUpdate(converter, atomicUpdate, loc);
           },
           [&](const Fortran::parser::AccAtomicCapture &atomicCapture) {
-            Fortran::lower::genOmpAccAtomicCapture<
-                Fortran::parser::AccAtomicCapture, void>(converter,
-                                                         atomicCapture, loc);
+            genAtomicCapture(converter, atomicCapture, loc);
           },
       },
       atomicConstruct.u);

diff  --git a/flang/lib/Lower/OpenMP/OpenMP.cpp b/flang/lib/Lower/OpenMP/OpenMP.cpp
index 312557d5da07e..fdd85e94829f3 100644
--- a/flang/lib/Lower/OpenMP/OpenMP.cpp
+++ b/flang/lib/Lower/OpenMP/OpenMP.cpp
@@ -2585,6 +2585,460 @@ genTeamsOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
       queue, item, clauseOps);
 }
 
+//===----------------------------------------------------------------------===//
+// Code generation for atomic operations
+//===----------------------------------------------------------------------===//
+
+/// Populates \p hint and \p memoryOrder with appropriate clause information
+/// if present on atomic construct.
+static void genOmpAtomicHintAndMemoryOrderClauses(
+    lower::AbstractConverter &converter,
+    const parser::OmpAtomicClauseList &clauseList, mlir::IntegerAttr &hint,
+    mlir::omp::ClauseMemoryOrderKindAttr &memoryOrder) {
+  fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
+  for (const parser::OmpAtomicClause &clause : clauseList.v) {
+    common::visit(
+        common::visitors{
+            [&](const parser::OmpMemoryOrderClause &s) {
+              auto kind = common::visit(
+                  common::visitors{
+                      [&](const parser::OmpClause::AcqRel &) {
+                        return mlir::omp::ClauseMemoryOrderKind::Acq_rel;
+                      },
+                      [&](const parser::OmpClause::Acquire &) {
+                        return mlir::omp::ClauseMemoryOrderKind::Acquire;
+                      },
+                      [&](const parser::OmpClause::Relaxed &) {
+                        return mlir::omp::ClauseMemoryOrderKind::Relaxed;
+                      },
+                      [&](const parser::OmpClause::Release &) {
+                        return mlir::omp::ClauseMemoryOrderKind::Release;
+                      },
+                      [&](const parser::OmpClause::SeqCst &) {
+                        return mlir::omp::ClauseMemoryOrderKind::Seq_cst;
+                      },
+                      [&](auto &&) -> mlir::omp::ClauseMemoryOrderKind {
+                        llvm_unreachable("Unexpected clause");
+                      },
+                  },
+                  s.v.u);
+              memoryOrder = mlir::omp::ClauseMemoryOrderKindAttr::get(
+                  firOpBuilder.getContext(), kind);
+            },
+            [&](const parser::OmpHintClause &s) {
+              const auto *expr = semantics::GetExpr(s.v);
+              uint64_t hintExprValue = *evaluate::ToInt64(*expr);
+              hint = firOpBuilder.getI64IntegerAttr(hintExprValue);
+            },
+            [&](const parser::OmpFailClause &) {},
+        },
+        clause.u);
+  }
+}
+
+static void processOmpAtomicTODO(mlir::Type elementType, mlir::Location loc) {
+  if (!elementType)
+    return;
+  assert(fir::isa_trivial(fir::unwrapRefType(elementType)) &&
+         "is supported type for omp atomic");
+}
+
+/// Used to generate atomic.read operation which is created in existing
+/// location set by builder.
+static void genAtomicCaptureStatement(
+    lower::AbstractConverter &converter, mlir::Value fromAddress,
+    mlir::Value toAddress,
+    const parser::OmpAtomicClauseList *leftHandClauseList,
+    const parser::OmpAtomicClauseList *rightHandClauseList,
+    mlir::Type elementType, mlir::Location loc) {
+  // Generate `atomic.read` operation for atomic assignment statements
+  fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
+
+  processOmpAtomicTODO(elementType, loc);
+
+  // If no hint clause is specified, the effect is as if
+  // hint(omp_sync_hint_none) had been specified.
+  mlir::IntegerAttr hint = nullptr;
+
+  mlir::omp::ClauseMemoryOrderKindAttr memoryOrder = nullptr;
+  if (leftHandClauseList)
+    genOmpAtomicHintAndMemoryOrderClauses(converter, *leftHandClauseList, hint,
+                                          memoryOrder);
+  if (rightHandClauseList)
+    genOmpAtomicHintAndMemoryOrderClauses(converter, *rightHandClauseList, hint,
+                                          memoryOrder);
+  firOpBuilder.create<mlir::omp::AtomicReadOp>(loc, fromAddress, toAddress,
+                                               mlir::TypeAttr::get(elementType),
+                                               hint, memoryOrder);
+}
+
+/// Used to generate atomic.write operation which is created in existing
+/// location set by builder.
+static void genAtomicWriteStatement(
+    lower::AbstractConverter &converter, mlir::Value lhsAddr,
+    mlir::Value rhsExpr, const parser::OmpAtomicClauseList *leftHandClauseList,
+    const parser::OmpAtomicClauseList *rightHandClauseList, mlir::Location loc,
+    mlir::Value *evaluatedExprValue = nullptr) {
+  // Generate `atomic.write` operation for atomic assignment statements
+  fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
+
+  mlir::Type varType = fir::unwrapRefType(lhsAddr.getType());
+  // Create a conversion outside the capture block.
+  auto insertionPoint = firOpBuilder.saveInsertionPoint();
+  firOpBuilder.setInsertionPointAfter(rhsExpr.getDefiningOp());
+  rhsExpr = firOpBuilder.createConvert(loc, varType, rhsExpr);
+  firOpBuilder.restoreInsertionPoint(insertionPoint);
+
+  processOmpAtomicTODO(varType, loc);
+
+  // If no hint clause is specified, the effect is as if
+  // hint(omp_sync_hint_none) had been specified.
+  mlir::IntegerAttr hint = nullptr;
+  mlir::omp::ClauseMemoryOrderKindAttr memoryOrder = nullptr;
+  if (leftHandClauseList)
+    genOmpAtomicHintAndMemoryOrderClauses(converter, *leftHandClauseList, hint,
+                                          memoryOrder);
+  if (rightHandClauseList)
+    genOmpAtomicHintAndMemoryOrderClauses(converter, *rightHandClauseList, hint,
+                                          memoryOrder);
+  firOpBuilder.create<mlir::omp::AtomicWriteOp>(loc, lhsAddr, rhsExpr, hint,
+                                                memoryOrder);
+}
+
+/// Used to generate atomic.update operation which is created in existing
+/// location set by builder.
+static void genAtomicUpdateStatement(
+    lower::AbstractConverter &converter, mlir::Value lhsAddr,
+    mlir::Type varType, const parser::Variable &assignmentStmtVariable,
+    const parser::Expr &assignmentStmtExpr,
+    const parser::OmpAtomicClauseList *leftHandClauseList,
+    const parser::OmpAtomicClauseList *rightHandClauseList, mlir::Location loc,
+    mlir::Operation *atomicCaptureOp = nullptr) {
+  // Generate `atomic.update` operation for atomic assignment statements
+  fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
+  mlir::Location currentLocation = converter.getCurrentLocation();
+
+  //  Create the omp.atomic.update operation
+  //
+  //  func.func @_QPsb() {
+  //    %0 = fir.alloca i32 {bindc_name = "a", uniq_name = "_QFsbEa"}
+  //    %1 = fir.alloca i32 {bindc_name = "b", uniq_name = "_QFsbEb"}
+  //    %2 = fir.load %1 : !fir.ref<i32>
+  //    omp.atomic.update   %0 : !fir.ref<i32> {
+  //    ^bb0(%arg0: i32):
+  //      %3 = arith.addi %arg0, %2 : i32
+  //      omp.yield(%3 : i32)
+  //    }
+  //    return
+  //  }
+
+  auto getArgExpression =
+      [](std::list<parser::ActualArgSpec>::const_iterator it) {
+        const auto &arg{std::get<parser::ActualArg>((*it).t)};
+        const auto *parserExpr{
+            std::get_if<common::Indirection<parser::Expr>>(&arg.u)};
+        return parserExpr;
+      };
+
+  // Lower any non atomic sub-expression before the atomic operation, and
+  // map its lowered value to the semantic representation.
+  lower::ExprToValueMap exprValueOverrides;
+  // Max and min intrinsics can have a list of Args. Hence we need a list
+  // of nonAtomicSubExprs to hoist. Currently, only the load is hoisted.
+  llvm::SmallVector<const lower::SomeExpr *> nonAtomicSubExprs;
+  common::visit(
+      common::visitors{
+          [&](const common::Indirection<parser::FunctionReference> &funcRef)
+              -> void {
+            const auto &args{std::get<std::list<parser::ActualArgSpec>>(
+                funcRef.value().v.t)};
+            std::list<parser::ActualArgSpec>::const_iterator beginIt =
+                args.begin();
+            std::list<parser::ActualArgSpec>::const_iterator endIt = args.end();
+            const auto *exprFirst{getArgExpression(beginIt)};
+            if (exprFirst && exprFirst->value().source ==
+                                 assignmentStmtVariable.GetSource()) {
+              // Add everything except the first
+              beginIt++;
+            } else {
+              // Add everything except the last
+              endIt--;
+            }
+            std::list<parser::ActualArgSpec>::const_iterator it;
+            for (it = beginIt; it != endIt; it++) {
+              const common::Indirection<parser::Expr> *expr =
+                  getArgExpression(it);
+              if (expr)
+                nonAtomicSubExprs.push_back(semantics::GetExpr(*expr));
+            }
+          },
+          [&](const auto &op) -> void {
+            using T = std::decay_t<decltype(op)>;
+            if constexpr (std::is_base_of<parser::Expr::IntrinsicBinary,
+                                          T>::value) {
+              const auto &exprLeft{std::get<0>(op.t)};
+              const auto &exprRight{std::get<1>(op.t)};
+              if (exprLeft.value().source == assignmentStmtVariable.GetSource())
+                nonAtomicSubExprs.push_back(semantics::GetExpr(exprRight));
+              else
+                nonAtomicSubExprs.push_back(semantics::GetExpr(exprLeft));
+            }
+          },
+      },
+      assignmentStmtExpr.u);
+  lower::StatementContext nonAtomicStmtCtx;
+  if (!nonAtomicSubExprs.empty()) {
+    // Generate non atomic part before all the atomic operations.
+    auto insertionPoint = firOpBuilder.saveInsertionPoint();
+    if (atomicCaptureOp)
+      firOpBuilder.setInsertionPoint(atomicCaptureOp);
+    mlir::Value nonAtomicVal;
+    for (auto *nonAtomicSubExpr : nonAtomicSubExprs) {
+      nonAtomicVal = fir::getBase(converter.genExprValue(
+          currentLocation, *nonAtomicSubExpr, nonAtomicStmtCtx));
+      exprValueOverrides.try_emplace(nonAtomicSubExpr, nonAtomicVal);
+    }
+    if (atomicCaptureOp)
+      firOpBuilder.restoreInsertionPoint(insertionPoint);
+  }
+
+  mlir::Operation *atomicUpdateOp = nullptr;
+  // If no hint clause is specified, the effect is as if
+  // hint(omp_sync_hint_none) had been specified.
+  mlir::IntegerAttr hint = nullptr;
+  mlir::omp::ClauseMemoryOrderKindAttr memoryOrder = nullptr;
+  if (leftHandClauseList)
+    genOmpAtomicHintAndMemoryOrderClauses(converter, *leftHandClauseList, hint,
+                                          memoryOrder);
+  if (rightHandClauseList)
+    genOmpAtomicHintAndMemoryOrderClauses(converter, *rightHandClauseList, hint,
+                                          memoryOrder);
+  atomicUpdateOp = firOpBuilder.create<mlir::omp::AtomicUpdateOp>(
+      currentLocation, lhsAddr, hint, memoryOrder);
+
+  processOmpAtomicTODO(varType, loc);
+
+  llvm::SmallVector<mlir::Type> varTys = {varType};
+  llvm::SmallVector<mlir::Location> locs = {currentLocation};
+  firOpBuilder.createBlock(&atomicUpdateOp->getRegion(0), {}, varTys, locs);
+  mlir::Value val =
+      fir::getBase(atomicUpdateOp->getRegion(0).front().getArgument(0));
+
+  exprValueOverrides.try_emplace(semantics::GetExpr(assignmentStmtVariable),
+                                 val);
+  {
+    // statement context inside the atomic block.
+    converter.overrideExprValues(&exprValueOverrides);
+    lower::StatementContext atomicStmtCtx;
+    mlir::Value rhsExpr = fir::getBase(converter.genExprValue(
+        *semantics::GetExpr(assignmentStmtExpr), atomicStmtCtx));
+    mlir::Value convertResult =
+        firOpBuilder.createConvert(currentLocation, varType, rhsExpr);
+    firOpBuilder.create<mlir::omp::YieldOp>(currentLocation, convertResult);
+    converter.resetExprOverrides();
+  }
+  firOpBuilder.setInsertionPointAfter(atomicUpdateOp);
+}
+
+/// Processes an atomic construct with write clause.
+static void genAtomicWrite(lower::AbstractConverter &converter,
+                           const parser::OmpAtomicWrite &atomicWrite,
+                           mlir::Location loc) {
+  const parser::OmpAtomicClauseList *rightHandClauseList = nullptr;
+  const parser::OmpAtomicClauseList *leftHandClauseList = nullptr;
+  // Get the clause lists attached to the atomic write construct.
+  rightHandClauseList = &std::get<2>(atomicWrite.t);
+  leftHandClauseList = &std::get<0>(atomicWrite.t);
+
+  const parser::AssignmentStmt &stmt =
+      std::get<parser::Statement<parser::AssignmentStmt>>(atomicWrite.t)
+          .statement;
+  const evaluate::Assignment &assign = *stmt.typedAssignment->v;
+  lower::StatementContext stmtCtx;
+  // Get the value and address of atomic write operands.
+  mlir::Value rhsExpr =
+      fir::getBase(converter.genExprValue(assign.rhs, stmtCtx));
+  mlir::Value lhsAddr =
+      fir::getBase(converter.genExprAddr(assign.lhs, stmtCtx));
+  genAtomicWriteStatement(converter, lhsAddr, rhsExpr, leftHandClauseList,
+                          rightHandClauseList, loc);
+}
+
+/// Processes an atomic construct with read clause.
+static void genAtomicRead(lower::AbstractConverter &converter,
+                          const parser::OmpAtomicRead &atomicRead,
+                          mlir::Location loc) {
+  const parser::OmpAtomicClauseList *rightHandClauseList = nullptr;
+  const parser::OmpAtomicClauseList *leftHandClauseList = nullptr;
+  // Get the clause lists attached to the atomic read construct.
+  rightHandClauseList = &std::get<2>(atomicRead.t);
+  leftHandClauseList = &std::get<0>(atomicRead.t);
+
+  const auto &assignmentStmtExpr = std::get<parser::Expr>(
+      std::get<parser::Statement<parser::AssignmentStmt>>(atomicRead.t)
+          .statement.t);
+  const auto &assignmentStmtVariable = std::get<parser::Variable>(
+      std::get<parser::Statement<parser::AssignmentStmt>>(atomicRead.t)
+          .statement.t);
+
+  lower::StatementContext stmtCtx;
+  const semantics::SomeExpr &fromExpr = *semantics::GetExpr(assignmentStmtExpr);
+  mlir::Type elementType = converter.genType(fromExpr);
+  mlir::Value fromAddress =
+      fir::getBase(converter.genExprAddr(fromExpr, stmtCtx));
+  mlir::Value toAddress = fir::getBase(converter.genExprAddr(
+      *semantics::GetExpr(assignmentStmtVariable), stmtCtx));
+  genAtomicCaptureStatement(converter, fromAddress, toAddress,
+                            leftHandClauseList, rightHandClauseList,
+                            elementType, loc);
+}
+
+/// Processes an atomic construct with update clause.
+static void genAtomicUpdate(lower::AbstractConverter &converter,
+                            const parser::OmpAtomicUpdate &atomicUpdate,
+                            mlir::Location loc) {
+  const parser::OmpAtomicClauseList *rightHandClauseList = nullptr;
+  const parser::OmpAtomicClauseList *leftHandClauseList = nullptr;
+  // Get the clause lists attached to the atomic update construct.
+  rightHandClauseList = &std::get<2>(atomicUpdate.t);
+  leftHandClauseList = &std::get<0>(atomicUpdate.t);
+
+  const auto &assignmentStmtExpr = std::get<parser::Expr>(
+      std::get<parser::Statement<parser::AssignmentStmt>>(atomicUpdate.t)
+          .statement.t);
+  const auto &assignmentStmtVariable = std::get<parser::Variable>(
+      std::get<parser::Statement<parser::AssignmentStmt>>(atomicUpdate.t)
+          .statement.t);
+
+  lower::StatementContext stmtCtx;
+  mlir::Value lhsAddr = fir::getBase(converter.genExprAddr(
+      *semantics::GetExpr(assignmentStmtVariable), stmtCtx));
+  mlir::Type varType = fir::unwrapRefType(lhsAddr.getType());
+  genAtomicUpdateStatement(converter, lhsAddr, varType, assignmentStmtVariable,
+                           assignmentStmtExpr, leftHandClauseList,
+                           rightHandClauseList, loc);
+}
+
+/// Processes an atomic construct with no atomic-clause, treated as an update.
+static void genOmpAtomic(lower::AbstractConverter &converter,
+                         const parser::OmpAtomic &atomicConstruct,
+                         mlir::Location loc) {
+  const parser::OmpAtomicClauseList &atomicClauseList =
+      std::get<parser::OmpAtomicClauseList>(atomicConstruct.t);
+  const auto &assignmentStmtExpr = std::get<parser::Expr>(
+      std::get<parser::Statement<parser::AssignmentStmt>>(atomicConstruct.t)
+          .statement.t);
+  const auto &assignmentStmtVariable = std::get<parser::Variable>(
+      std::get<parser::Statement<parser::AssignmentStmt>>(atomicConstruct.t)
+          .statement.t);
+  lower::StatementContext stmtCtx;
+  mlir::Value lhsAddr = fir::getBase(converter.genExprAddr(
+      *semantics::GetExpr(assignmentStmtVariable), stmtCtx));
+  mlir::Type varType = fir::unwrapRefType(lhsAddr.getType());
+  // With no atomic-clause the construct behaves as if UPDATE were specified;
+  // its single clause list is passed as the left-hand list, the right as null.
+  genAtomicUpdateStatement(converter, lhsAddr, varType, assignmentStmtVariable,
+                           assignmentStmtExpr, &atomicClauseList, nullptr, loc);
+}
+
+/// Processes an atomic capture construct into an mlir::omp::AtomicCaptureOp.
+static void genAtomicCapture(lower::AbstractConverter &converter,
+                             const parser::OmpAtomicCapture &atomicCapture,
+                             mlir::Location loc) {
+  fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
+
+  const parser::AssignmentStmt &stmt1 =
+      std::get<parser::OmpAtomicCapture::Stmt1>(atomicCapture.t).v.statement;
+  const evaluate::Assignment &assign1 = *stmt1.typedAssignment->v;
+  const auto &stmt1Var{std::get<parser::Variable>(stmt1.t)};
+  const auto &stmt1Expr{std::get<parser::Expr>(stmt1.t)};
+  const parser::AssignmentStmt &stmt2 =
+      std::get<parser::OmpAtomicCapture::Stmt2>(atomicCapture.t).v.statement;
+  const evaluate::Assignment &assign2 = *stmt2.typedAssignment->v;
+  const auto &stmt2Var{std::get<parser::Variable>(stmt2.t)};
+  const auto &stmt2Expr{std::get<parser::Expr>(stmt2.t)};
+
+  // Pre-evaluate expressions to be used in the various operations inside
+  // `atomic.capture` since it is not desirable to have anything other than
+  // an `atomic.read`, `atomic.write`, or `atomic.update` operation
+  // inside `atomic.capture`.
+  lower::StatementContext stmtCtx;
+  // LHS addresses are needed by every form of `atomic.capture` handled below.
+  mlir::Value stmt1LHSArg =
+      fir::getBase(converter.genExprAddr(assign1.lhs, stmtCtx));
+  mlir::Value stmt2LHSArg =
+      fir::getBase(converter.genExprAddr(assign2.lhs, stmtCtx));
+
+  // Value types of both LHS, used when generating `atomic.update` operations.
+  mlir::Type stmt1VarType =
+      fir::getBase(converter.genExprValue(assign1.lhs, stmtCtx)).getType();
+  mlir::Type stmt2VarType =
+      fir::getBase(converter.genExprValue(assign2.lhs, stmtCtx)).getType();
+
+  mlir::Operation *atomicCaptureOp = nullptr;
+  mlir::IntegerAttr hint = nullptr;
+  mlir::omp::ClauseMemoryOrderKindAttr memoryOrder = nullptr;
+  const parser::OmpAtomicClauseList &rightHandClauseList =
+      std::get<2>(atomicCapture.t);
+  const parser::OmpAtomicClauseList &leftHandClauseList =
+      std::get<0>(atomicCapture.t);
+  genOmpAtomicHintAndMemoryOrderClauses(converter, leftHandClauseList, hint,
+                                        memoryOrder);
+  genOmpAtomicHintAndMemoryOrderClauses(converter, rightHandClauseList, hint,
+                                        memoryOrder);
+  atomicCaptureOp =
+      firOpBuilder.create<mlir::omp::AtomicCaptureOp>(loc, hint, memoryOrder);
+
+  firOpBuilder.createBlock(&(atomicCaptureOp->getRegion(0)));
+  mlir::Block &block = atomicCaptureOp->getRegion(0).back();
+  firOpBuilder.setInsertionPointToStart(&block);
+  if (semantics::checkForSingleVariableOnRHS(stmt1)) {
+    if (semantics::checkForSymbolMatch(stmt2)) {
+      // Atomic capture construct is of the form [capture-stmt, update-stmt].
+      const semantics::SomeExpr &fromExpr = *semantics::GetExpr(stmt1Expr);
+      mlir::Type elementType = converter.genType(fromExpr);
+      genAtomicCaptureStatement(converter, stmt2LHSArg, stmt1LHSArg,
+                                /*leftHandClauseList=*/nullptr,
+                                /*rightHandClauseList=*/nullptr, elementType,
+                                loc);
+      genAtomicUpdateStatement(
+          converter, stmt2LHSArg, stmt2VarType, stmt2Var, stmt2Expr,
+          /*leftHandClauseList=*/nullptr,
+          /*rightHandClauseList=*/nullptr, loc, atomicCaptureOp);
+    } else {
+      // Atomic capture construct is of the form [capture-stmt, write-stmt].
+      firOpBuilder.setInsertionPoint(atomicCaptureOp); // evaluate RHS outside
+      mlir::Value stmt2RHSArg =
+          fir::getBase(converter.genExprValue(assign2.rhs, stmtCtx));
+      firOpBuilder.setInsertionPointToStart(&block);
+      const semantics::SomeExpr &fromExpr = *semantics::GetExpr(stmt1Expr);
+      mlir::Type elementType = converter.genType(fromExpr);
+      genAtomicCaptureStatement(converter, stmt2LHSArg, stmt1LHSArg,
+                                /*leftHandClauseList=*/nullptr,
+                                /*rightHandClauseList=*/nullptr, elementType,
+                                loc);
+      genAtomicWriteStatement(converter, stmt2LHSArg, stmt2RHSArg,
+                              /*leftHandClauseList=*/nullptr,
+                              /*rightHandClauseList=*/nullptr, loc);
+    }
+  } else {
+    // Atomic capture construct is of the form [update-stmt, capture-stmt].
+    const semantics::SomeExpr &fromExpr = *semantics::GetExpr(stmt2Expr);
+    mlir::Type elementType = converter.genType(fromExpr);
+    genAtomicUpdateStatement(
+        converter, stmt1LHSArg, stmt1VarType, stmt1Var, stmt1Expr,
+        /*leftHandClauseList=*/nullptr,
+        /*rightHandClauseList=*/nullptr, loc, atomicCaptureOp);
+    genAtomicCaptureStatement(converter, stmt1LHSArg, stmt2LHSArg,
+                              /*leftHandClauseList=*/nullptr,
+                              /*rightHandClauseList=*/nullptr, elementType,
+                              loc);
+  }
+  firOpBuilder.setInsertionPointToEnd(&block);
+  firOpBuilder.create<mlir::omp::TerminatorOp>(loc);
+  firOpBuilder.setInsertionPointToStart(&block);
+}
+
 //===----------------------------------------------------------------------===//
 // Code generation functions for the standalone version of constructs that can
 // also be a leaf of a composite construct
@@ -3476,32 +3930,23 @@ static void genOMP(lower::AbstractConverter &converter, lower::SymMap &symTable,
       common::visitors{
           [&](const parser::OmpAtomicRead &atomicRead) {
             mlir::Location loc = converter.genLocation(atomicRead.source);
-            lower::genOmpAccAtomicRead<parser::OmpAtomicRead,
-                                       parser::OmpAtomicClauseList>(
-                converter, atomicRead, loc);
+            genAtomicRead(converter, atomicRead, loc);
           },
           [&](const parser::OmpAtomicWrite &atomicWrite) {
             mlir::Location loc = converter.genLocation(atomicWrite.source);
-            lower::genOmpAccAtomicWrite<parser::OmpAtomicWrite,
-                                        parser::OmpAtomicClauseList>(
-                converter, atomicWrite, loc);
+            genAtomicWrite(converter, atomicWrite, loc);
           },
           [&](const parser::OmpAtomic &atomicConstruct) {
             mlir::Location loc = converter.genLocation(atomicConstruct.source);
-            lower::genOmpAtomic<parser::OmpAtomic, parser::OmpAtomicClauseList>(
-                converter, atomicConstruct, loc);
+            genOmpAtomic(converter, atomicConstruct, loc);
           },
           [&](const parser::OmpAtomicUpdate &atomicUpdate) {
             mlir::Location loc = converter.genLocation(atomicUpdate.source);
-            lower::genOmpAccAtomicUpdate<parser::OmpAtomicUpdate,
-                                         parser::OmpAtomicClauseList>(
-                converter, atomicUpdate, loc);
+            genAtomicUpdate(converter, atomicUpdate, loc);
           },
           [&](const parser::OmpAtomicCapture &atomicCapture) {
             mlir::Location loc = converter.genLocation(atomicCapture.source);
-            lower::genOmpAccAtomicCapture<parser::OmpAtomicCapture,
-                                          parser::OmpAtomicClauseList>(
-                converter, atomicCapture, loc);
+            genAtomicCapture(converter, atomicCapture, loc);
           },
           [&](const parser::OmpAtomicCompare &atomicCompare) {
             mlir::Location loc = converter.genLocation(atomicCompare.source);


        


More information about the flang-commits mailing list