[flang-commits] [flang] [Flang][OpenMP][OpenACC] Handle atomic read/capture when lhs and rhs … (PR #93776)

via flang-commits flang-commits at lists.llvm.org
Thu May 30 00:16:58 PDT 2024


https://github.com/harishch4 created https://github.com/llvm/llvm-project/pull/93776

…types are different

Fixes: #83722

Changed the evaluated expression to a typed expression for atomic reads to keep it consistent with Clang. Atomic loads now happen on the actual memory location rather than on a converted value.

HLFIR generation for complex types is handled as well, but lowering it to LLVM IR is still a WIP (that should fix #93441).

>From ba975743902a6bbcc6e44415d3cdf90b930fcbb1 Mon Sep 17 00:00:00 2001
From: Harish Chambeti <harishcse44 at gmail.com>
Date: Thu, 30 May 2024 12:29:28 +0530
Subject: [PATCH] [Flang][OpenMP][OpenACC] Handle atomic read/capture when lhs
 and rhs types are different

---
 flang/lib/Lower/DirectivesCommon.h           | 72 +++++++++++++++-----
 flang/test/Lower/OpenACC/acc-atomic-read.f90 |  8 ++-
 flang/test/Lower/OpenMP/atomic-capture.f90   | 30 ++++++++
 flang/test/Lower/OpenMP/atomic-read.f90      | 37 ++++++++++
 4 files changed, 129 insertions(+), 18 deletions(-)

diff --git a/flang/lib/Lower/DirectivesCommon.h b/flang/lib/Lower/DirectivesCommon.h
index 48b090f6d2dbe..d97ae2c5f51f4 100644
--- a/flang/lib/Lower/DirectivesCommon.h
+++ b/flang/lib/Lower/DirectivesCommon.h
@@ -30,6 +30,7 @@
 #include "flang/Lower/StatementContext.h"
 #include "flang/Lower/Support/Utils.h"
 #include "flang/Optimizer/Builder/BoxValue.h"
+#include "flang/Optimizer/Builder/Complex.h"
 #include "flang/Optimizer/Builder/FIRBuilder.h"
 #include "flang/Optimizer/Builder/HLFIRTools.h"
 #include "flang/Optimizer/Builder/Todo.h"
@@ -143,9 +144,24 @@ static inline void genOmpAccAtomicCaptureStatement(
     mlir::Value toAddress,
     [[maybe_unused]] const AtomicListT *leftHandClauseList,
     [[maybe_unused]] const AtomicListT *rightHandClauseList,
-    mlir::Type elementType, mlir::Location loc) {
+    mlir::Type elementType, mlir::Location loc,
+    mlir::Operation *atomicCaptureOp = nullptr) {
   // Generate `atomic.read` operation for atomic assigment statements
   fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
+  mlir::Value oldToAddress = toAddress;
+  if (fromAddress.getType() != oldToAddress.getType()) {
+    auto insertionPoint = firOpBuilder.saveInsertionPoint();
+    if (atomicCaptureOp)
+      firOpBuilder.setInsertionPoint(atomicCaptureOp);
+    auto alloca = firOpBuilder.create<fir::AllocaOp>(loc, elementType);
+    auto declareOp = firOpBuilder.create<hlfir::DeclareOp>(
+        loc, alloca, ".atomic.read.temp", /*shape=*/nullptr,
+        llvm::ArrayRef<mlir::Value>{},
+        /*dummy_scope=*/nullptr, fir::FortranVariableFlagsAttr{});
+    toAddress = declareOp.getBase();
+    if (atomicCaptureOp)
+      firOpBuilder.restoreInsertionPoint(insertionPoint);
+  }
 
   if constexpr (std::is_same<AtomicListT,
                              Fortran::parser::OmpAtomicClauseList>()) {
@@ -167,6 +183,24 @@ static inline void genOmpAccAtomicCaptureStatement(
     firOpBuilder.create<mlir::acc::AtomicReadOp>(
         loc, fromAddress, toAddress, mlir::TypeAttr::get(elementType));
   }
+
+  if (fromAddress.getType() != oldToAddress.getType()) {
+    auto insertionPoint = firOpBuilder.saveInsertionPoint();
+    if (atomicCaptureOp)
+      firOpBuilder.setInsertionPointAfter(atomicCaptureOp);
+    mlir::Value load = firOpBuilder.create<fir::LoadOp>(loc, toAddress);
+    if (auto cmplxTy = mlir::dyn_cast_or_null<fir::ComplexType>(elementType)) {
+      mlir::Value extractValue =
+          fir::factory::Complex{firOpBuilder, loc}.extractComplexPart(load,
+                                                                      false);
+      load = extractValue;
+    }
+    mlir::Value convert = firOpBuilder.create<fir::ConvertOp>(
+        loc, fir::unwrapRefType(oldToAddress.getType()), load);
+    firOpBuilder.create<fir::StoreOp>(loc, convert, oldToAddress);
+    if (atomicCaptureOp)
+      firOpBuilder.restoreInsertionPoint(insertionPoint);
+  }
 }
 
 /// Used to generate atomic.write operation which is created in existing
@@ -408,10 +442,6 @@ void genOmpAccAtomicRead(Fortran::lower::AbstractConverter &converter,
       fir::getBase(converter.genExprAddr(fromExpr, stmtCtx));
   mlir::Value toAddress = fir::getBase(converter.genExprAddr(
       *Fortran::semantics::GetExpr(assignmentStmtVariable), stmtCtx));
-  fir::FirOpBuilder &builder = converter.getFirOpBuilder();
-  if (fromAddress.getType() != toAddress.getType())
-    fromAddress =
-        builder.create<fir::ConvertOp>(loc, toAddress.getType(), fromAddress);
   genOmpAccAtomicCaptureStatement(converter, fromAddress, toAddress,
                                   leftHandClauseList, rightHandClauseList,
                                   elementType, loc);
@@ -481,12 +511,10 @@ void genOmpAccAtomicCapture(Fortran::lower::AbstractConverter &converter,
 
   const Fortran::parser::AssignmentStmt &stmt1 =
       std::get<typename AtomicT::Stmt1>(atomicCapture.t).v.statement;
-  const Fortran::evaluate::Assignment &assign1 = *stmt1.typedAssignment->v;
   const auto &stmt1Var{std::get<Fortran::parser::Variable>(stmt1.t)};
   const auto &stmt1Expr{std::get<Fortran::parser::Expr>(stmt1.t)};
   const Fortran::parser::AssignmentStmt &stmt2 =
       std::get<typename AtomicT::Stmt2>(atomicCapture.t).v.statement;
-  const Fortran::evaluate::Assignment &assign2 = *stmt2.typedAssignment->v;
   const auto &stmt2Var{std::get<Fortran::parser::Variable>(stmt2.t)};
   const auto &stmt2Expr{std::get<Fortran::parser::Expr>(stmt2.t)};
 
@@ -498,25 +526,37 @@ void genOmpAccAtomicCapture(Fortran::lower::AbstractConverter &converter,
   mlir::Value stmt1LHSArg, stmt1RHSArg, stmt2LHSArg, stmt2RHSArg;
   mlir::Type elementType;
   // LHS evaluations are common to all combinations of `atomic.capture`
-  stmt1LHSArg = fir::getBase(converter.genExprAddr(assign1.lhs, stmtCtx));
-  stmt2LHSArg = fir::getBase(converter.genExprAddr(assign2.lhs, stmtCtx));
+  stmt1LHSArg = fir::getBase(
+      converter.genExprAddr(*Fortran::semantics::GetExpr(stmt1Var), stmtCtx));
+  stmt2LHSArg = fir::getBase(
+      converter.genExprAddr(*Fortran::semantics::GetExpr(stmt2Var), stmtCtx));
 
   // Operation specific RHS evaluations
   if (checkForSingleVariableOnRHS(stmt1)) {
     // Atomic capture construct is of the form [capture-stmt, update-stmt] or
     // of the form [capture-stmt, write-stmt]
-    stmt1RHSArg = fir::getBase(converter.genExprAddr(assign1.rhs, stmtCtx));
+    stmt1RHSArg = fir::getBase(converter.genExprAddr(
+        *Fortran::semantics::GetExpr(stmt1Expr), stmtCtx));
+    // To handle type convert for atomic write/update.
+    const Fortran::evaluate::Assignment &assign2 = *stmt2.typedAssignment->v;
     stmt2RHSArg = fir::getBase(converter.genExprValue(assign2.rhs, stmtCtx));
   } else {
     // Atomic capture construct is of the form [update-stmt, capture-stmt]
+    // To handle type convert for atomic update.
+    const Fortran::evaluate::Assignment &assign1 = *stmt1.typedAssignment->v;
     stmt1RHSArg = fir::getBase(converter.genExprValue(assign1.rhs, stmtCtx));
-    stmt2RHSArg = fir::getBase(converter.genExprAddr(assign2.lhs, stmtCtx));
+    stmt2RHSArg = fir::getBase(converter.genExprAddr(
+        *Fortran::semantics::GetExpr(stmt2Expr), stmtCtx));
   }
   // Type information used in generation of `atomic.update` operation
   mlir::Type stmt1VarType =
-      fir::getBase(converter.genExprValue(assign1.lhs, stmtCtx)).getType();
+      fir::getBase(converter.genExprValue(
+                       *Fortran::semantics::GetExpr(stmt1Var), stmtCtx))
+          .getType();
   mlir::Type stmt2VarType =
-      fir::getBase(converter.genExprValue(assign2.lhs, stmtCtx)).getType();
+      fir::getBase(converter.genExprValue(
+                       *Fortran::semantics::GetExpr(stmt2Var), stmtCtx))
+          .getType();
 
   mlir::Operation *atomicCaptureOp = nullptr;
   if constexpr (std::is_same<AtomicListT,
@@ -547,7 +587,7 @@ void genOmpAccAtomicCapture(Fortran::lower::AbstractConverter &converter,
       genOmpAccAtomicCaptureStatement<AtomicListT>(
           converter, stmt1RHSArg, stmt1LHSArg,
           /*leftHandClauseList=*/nullptr,
-          /*rightHandClauseList=*/nullptr, elementType, loc);
+          /*rightHandClauseList=*/nullptr, elementType, loc, atomicCaptureOp);
       genOmpAccAtomicUpdateStatement<AtomicListT>(
           converter, stmt1RHSArg, stmt2VarType, stmt2Var, stmt2Expr,
           /*leftHandClauseList=*/nullptr,
@@ -560,7 +600,7 @@ void genOmpAccAtomicCapture(Fortran::lower::AbstractConverter &converter,
       genOmpAccAtomicCaptureStatement<AtomicListT>(
           converter, stmt1RHSArg, stmt1LHSArg,
           /*leftHandClauseList=*/nullptr,
-          /*rightHandClauseList=*/nullptr, elementType, loc);
+          /*rightHandClauseList=*/nullptr, elementType, loc, atomicCaptureOp);
       genOmpAccAtomicWriteStatement<AtomicListT>(
           converter, stmt1RHSArg, stmt2RHSArg,
           /*leftHandClauseList=*/nullptr,
@@ -575,7 +615,7 @@ void genOmpAccAtomicCapture(Fortran::lower::AbstractConverter &converter,
     genOmpAccAtomicCaptureStatement<AtomicListT>(
         converter, stmt1LHSArg, stmt2LHSArg,
         /*leftHandClauseList=*/nullptr,
-        /*rightHandClauseList=*/nullptr, elementType, loc);
+        /*rightHandClauseList=*/nullptr, elementType, loc, atomicCaptureOp);
     firOpBuilder.setInsertionPointToStart(&block);
     genOmpAccAtomicUpdateStatement<AtomicListT>(
         converter, stmt1LHSArg, stmt1VarType, stmt1Var, stmt1Expr,
diff --git a/flang/test/Lower/OpenACC/acc-atomic-read.f90 b/flang/test/Lower/OpenACC/acc-atomic-read.f90
index c1a97a9e5f74f..5c59c86236d4a 100644
--- a/flang/test/Lower/OpenACC/acc-atomic-read.f90
+++ b/flang/test/Lower/OpenACC/acc-atomic-read.f90
@@ -55,5 +55,9 @@ subroutine atomic_read_with_convert()
 ! CHECK: %[[X_DECL:.*]]:2 = hlfir.declare %[[X]] {uniq_name = "_QFatomic_read_with_convertEx"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
 ! CHECK: %[[Y:.*]] = fir.alloca i64 {bindc_name = "y", uniq_name = "_QFatomic_read_with_convertEy"}
 ! CHECK: %[[Y_DECL:.*]]:2 = hlfir.declare %[[Y]] {uniq_name = "_QFatomic_read_with_convertEy"} : (!fir.ref<i64>) -> (!fir.ref<i64>, !fir.ref<i64>)
-! CHECK: %[[CONV:.*]] = fir.convert %[[X_DECL]]#1 : (!fir.ref<i32>) -> !fir.ref<i64>
-! CHECK: acc.atomic.read %[[Y_DECL]]#1 = %[[CONV]] : !fir.ref<i64>, i32
+! CHECK: %[[TEMP:.*]] = fir.alloca i32
+! CHECK: %[[TEMP_DECL:.*]]:2 = hlfir.declare %[[TEMP]] {uniq_name = ".atomic.read.temp"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
+! CHECK: acc.atomic.read %[[TEMP_DECL]]#0 = %1#1 : !fir.ref<i32>, i32
+! CHECK: %[[TEMP_LD:.*]] = fir.load %[[TEMP_DECL]]#0 : !fir.ref<i32>
+! CHECK: %[[TEMP_CVT:.*]] = fir.convert %[[TEMP_LD]] : (i32) -> i64
+! CHECK: fir.store %[[TEMP_CVT]] to %[[Y_DECL]]#1 : !fir.ref<i64>
diff --git a/flang/test/Lower/OpenMP/atomic-capture.f90 b/flang/test/Lower/OpenMP/atomic-capture.f90
index 32d8cd7bbf328..6489a560b77b0 100644
--- a/flang/test/Lower/OpenMP/atomic-capture.f90
+++ b/flang/test/Lower/OpenMP/atomic-capture.f90
@@ -97,3 +97,33 @@ subroutine pointers_in_atomic_capture()
         b = a
     !$omp end atomic
 end subroutine
+
+! CHECK-LABEL:   func.func @_QPcapture_with_convert() {
+! CHECK:           %[[VAL_0:.*]] = fir.alloca f32 {bindc_name = "c", uniq_name = "_QFcapture_with_convertEc"}
+! CHECK:           %[[VAL_1:.*]]:2 = hlfir.declare %[[VAL_0]] {uniq_name = "_QFcapture_with_convertEc"} : (!fir.ref<f32>) -> (!fir.ref<f32>, !fir.ref<f32>)
+! CHECK:           %[[VAL_2:.*]] = fir.alloca f64 {bindc_name = "c2", uniq_name = "_QFcapture_with_convertEc2"}
+! CHECK:           %[[VAL_3:.*]]:2 = hlfir.declare %[[VAL_2]] {uniq_name = "_QFcapture_with_convertEc2"} : (!fir.ref<f64>) -> (!fir.ref<f64>, !fir.ref<f64>)
+! CHECK:           %[[VAL_4:.*]] = fir.alloca f32
+! CHECK:           %[[VAL_5:.*]]:2 = hlfir.declare %[[VAL_4]] {uniq_name = ".atomic.read.temp"} : (!fir.ref<f32>) -> (!fir.ref<f32>, !fir.ref<f32>)
+! CHECK:           %[[VAL_6:.*]] = arith.constant 2.000000e+00 : f32
+! CHECK:           omp.atomic.capture {
+! CHECK:             omp.atomic.read %[[VAL_5]]#0 = %[[VAL_1]]#1 : !fir.ref<f32>, f32
+! CHECK:             omp.atomic.update %[[VAL_1]]#1 : !fir.ref<f32> {
+! CHECK:             ^bb0(%[[VAL_7:.*]]: f32):
+! CHECK:               %[[VAL_8:.*]] = arith.mulf %[[VAL_6]], %[[VAL_7]] fastmath<contract> : f32
+! CHECK:               omp.yield(%[[VAL_8]] : f32)
+! CHECK:             }
+! CHECK:           }
+! CHECK:           %[[VAL_9:.*]] = fir.load %[[VAL_5]]#0 : !fir.ref<f32>
+! CHECK:           %[[VAL_10:.*]] = fir.convert %[[VAL_9]] : (f32) -> f64
+! CHECK:           fir.store %[[VAL_10]] to %[[VAL_3]]#1 : !fir.ref<f64>
+! CHECK:           return
+! CHECK:         }
+subroutine capture_with_convert()
+    real :: c
+    double precision :: c2
+!$omp atomic capture
+    c2 = c
+    c = 2.0 * c
+!$omp end atomic
+end
diff --git a/flang/test/Lower/OpenMP/atomic-read.f90 b/flang/test/Lower/OpenMP/atomic-read.f90
index 8c3f37c94975e..940c0d61d91ca 100644
--- a/flang/test/Lower/OpenMP/atomic-read.f90
+++ b/flang/test/Lower/OpenMP/atomic-read.f90
@@ -89,3 +89,40 @@ subroutine atomic_read_pointer()
   x = y
 end
 
+! CHECK-LABEL:   func.func @_QPread_with_convert() {
+! CHECK:           %[[VAL_0:.*]] = fir.alloca f32 {bindc_name = "a", uniq_name = "_QFread_with_convertEa"}
+! CHECK:           %[[VAL_1:.*]]:2 = hlfir.declare %[[VAL_0]] {uniq_name = "_QFread_with_convertEa"} : (!fir.ref<f32>) -> (!fir.ref<f32>, !fir.ref<f32>)
+! CHECK:           %[[VAL_2:.*]] = fir.alloca i32 {bindc_name = "b", uniq_name = "_QFread_with_convertEb"}
+! CHECK:           %[[VAL_3:.*]]:2 = hlfir.declare %[[VAL_2]] {uniq_name = "_QFread_with_convertEb"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
+! CHECK:           %[[VAL_4:.*]] = fir.alloca i32
+! CHECK:           %[[VAL_5:.*]]:2 = hlfir.declare %[[VAL_4]] {uniq_name = ".atomic.read.temp"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
+! CHECK:           omp.atomic.read %[[VAL_5]]#0 = %[[VAL_3]]#1 : !fir.ref<i32>, i32
+! CHECK:           %[[VAL_6:.*]] = fir.load %[[VAL_5]]#0 : !fir.ref<i32>
+! CHECK:           %[[VAL_7:.*]] = fir.convert %[[VAL_6]] : (i32) -> f32
+! CHECK:           fir.store %[[VAL_7]] to %[[VAL_1]]#1 : !fir.ref<f32>
+subroutine read_with_convert()
+   real :: a
+   integer :: b
+   !$omp atomic read
+   a = b
+end
+
+! CHECK-LABEL:   func.func @_QPread_complex_with_convert() {
+! CHECK:           %[[VAL_0:.*]] = fir.alloca f32 {bindc_name = "s_v_r2", uniq_name = "_QFread_complex_with_convertEs_v_r2"}
+! CHECK:           %[[VAL_1:.*]]:2 = hlfir.declare %[[VAL_0]] {uniq_name = "_QFread_complex_with_convertEs_v_r2"} : (!fir.ref<f32>) -> (!fir.ref<f32>, !fir.ref<f32>)
+! CHECK:           %[[VAL_2:.*]] = fir.alloca !fir.complex<4> {bindc_name = "s_x_c2", uniq_name = "_QFread_complex_with_convertEs_x_c2"}
+! CHECK:           %[[VAL_3:.*]]:2 = hlfir.declare %[[VAL_2]] {uniq_name = "_QFread_complex_with_convertEs_x_c2"} : (!fir.ref<!fir.complex<4>>) -> (!fir.ref<!fir.complex<4>>, !fir.ref<!fir.complex<4>>)
+! CHECK:           %[[VAL_4:.*]] = fir.alloca !fir.complex<4>
+! CHECK:           %[[VAL_5:.*]]:2 = hlfir.declare %[[VAL_4]] {uniq_name = ".atomic.read.temp"} : (!fir.ref<!fir.complex<4>>) -> (!fir.ref<!fir.complex<4>>, !fir.ref<!fir.complex<4>>)
+! CHECK:           omp.atomic.read %[[VAL_5]]#0 = %[[VAL_3]]#1 : !fir.ref<!fir.complex<4>>, !fir.complex<4>
+! CHECK:           %[[VAL_6:.*]] = fir.load %[[VAL_5]]#0 : !fir.ref<!fir.complex<4>>
+! CHECK:           %[[VAL_7:.*]] = fir.extract_value %[[VAL_6]], [0 : index] : (!fir.complex<4>) -> f32
+! CHECK:           %[[VAL_8:.*]] = fir.convert %[[VAL_7]] : (f32) -> f32
+! CHECK:           fir.store %[[VAL_8]] to %[[VAL_1]]#1 : !fir.ref<f32>
+subroutine read_complex_with_convert()
+   real(kind=4)    :: s_v_r2
+   complex(kind=4) :: s_x_c2
+ !$omp atomic read
+    s_v_r2 = s_x_c2
+ !$omp end atomic
+end



More information about the flang-commits mailing list