[flang-commits] [flang] [Flang][OpenMP][OpenACC] Handle atomic read/capture when lhs and rhs … (PR #93776)
via flang-commits
flang-commits at lists.llvm.org
Thu May 30 00:17:31 PDT 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-flang-openmp
Author: None (harishch4)
<details>
<summary>Changes</summary>
Handle atomic read/capture when lhs and rhs types are different
Fixes: #83722
Atomic reads now lower the typed expression instead of the evaluated expression, which keeps the behavior consistent with clang. The atomic load now happens on the actual memory location rather than on a converted value.
HLFIR generation for complex types is handled as well, but lowering it to LLVM IR is still a work in progress (that should fix #93441).
---
Full diff: https://github.com/llvm/llvm-project/pull/93776.diff
4 Files Affected:
- (modified) flang/lib/Lower/DirectivesCommon.h (+56-16)
- (modified) flang/test/Lower/OpenACC/acc-atomic-read.f90 (+6-2)
- (modified) flang/test/Lower/OpenMP/atomic-capture.f90 (+30)
- (modified) flang/test/Lower/OpenMP/atomic-read.f90 (+37)
``````````diff
diff --git a/flang/lib/Lower/DirectivesCommon.h b/flang/lib/Lower/DirectivesCommon.h
index 48b090f6d2dbe..d97ae2c5f51f4 100644
--- a/flang/lib/Lower/DirectivesCommon.h
+++ b/flang/lib/Lower/DirectivesCommon.h
@@ -30,6 +30,7 @@
#include "flang/Lower/StatementContext.h"
#include "flang/Lower/Support/Utils.h"
#include "flang/Optimizer/Builder/BoxValue.h"
+#include "flang/Optimizer/Builder/Complex.h"
#include "flang/Optimizer/Builder/FIRBuilder.h"
#include "flang/Optimizer/Builder/HLFIRTools.h"
#include "flang/Optimizer/Builder/Todo.h"
@@ -143,9 +144,24 @@ static inline void genOmpAccAtomicCaptureStatement(
mlir::Value toAddress,
[[maybe_unused]] const AtomicListT *leftHandClauseList,
[[maybe_unused]] const AtomicListT *rightHandClauseList,
- mlir::Type elementType, mlir::Location loc) {
+ mlir::Type elementType, mlir::Location loc,
+ mlir::Operation *atomicCaptureOp = nullptr) {
// Generate `atomic.read` operation for atomic assigment statements
fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
+ mlir::Value oldToAddress = toAddress;
+ if (fromAddress.getType() != oldToAddress.getType()) {
+ auto insertionPoint = firOpBuilder.saveInsertionPoint();
+ if (atomicCaptureOp)
+ firOpBuilder.setInsertionPoint(atomicCaptureOp);
+ auto alloca = firOpBuilder.create<fir::AllocaOp>(loc, elementType);
+ auto declareOp = firOpBuilder.create<hlfir::DeclareOp>(
+ loc, alloca, ".atomic.read.temp", /*shape=*/nullptr,
+ llvm::ArrayRef<mlir::Value>{},
+ /*dummy_scope=*/nullptr, fir::FortranVariableFlagsAttr{});
+ toAddress = declareOp.getBase();
+ if (atomicCaptureOp)
+ firOpBuilder.restoreInsertionPoint(insertionPoint);
+ }
if constexpr (std::is_same<AtomicListT,
Fortran::parser::OmpAtomicClauseList>()) {
@@ -167,6 +183,24 @@ static inline void genOmpAccAtomicCaptureStatement(
firOpBuilder.create<mlir::acc::AtomicReadOp>(
loc, fromAddress, toAddress, mlir::TypeAttr::get(elementType));
}
+
+ if (fromAddress.getType() != oldToAddress.getType()) {
+ auto insertionPoint = firOpBuilder.saveInsertionPoint();
+ if (atomicCaptureOp)
+ firOpBuilder.setInsertionPointAfter(atomicCaptureOp);
+ mlir::Value load = firOpBuilder.create<fir::LoadOp>(loc, toAddress);
+ if (auto cmplxTy = mlir::dyn_cast_or_null<fir::ComplexType>(elementType)) {
+ mlir::Value extractValue =
+ fir::factory::Complex{firOpBuilder, loc}.extractComplexPart(load,
+ false);
+ load = extractValue;
+ }
+ mlir::Value convert = firOpBuilder.create<fir::ConvertOp>(
+ loc, fir::unwrapRefType(oldToAddress.getType()), load);
+ firOpBuilder.create<fir::StoreOp>(loc, convert, oldToAddress);
+ if (atomicCaptureOp)
+ firOpBuilder.restoreInsertionPoint(insertionPoint);
+ }
}
/// Used to generate atomic.write operation which is created in existing
@@ -408,10 +442,6 @@ void genOmpAccAtomicRead(Fortran::lower::AbstractConverter &converter,
fir::getBase(converter.genExprAddr(fromExpr, stmtCtx));
mlir::Value toAddress = fir::getBase(converter.genExprAddr(
*Fortran::semantics::GetExpr(assignmentStmtVariable), stmtCtx));
- fir::FirOpBuilder &builder = converter.getFirOpBuilder();
- if (fromAddress.getType() != toAddress.getType())
- fromAddress =
- builder.create<fir::ConvertOp>(loc, toAddress.getType(), fromAddress);
genOmpAccAtomicCaptureStatement(converter, fromAddress, toAddress,
leftHandClauseList, rightHandClauseList,
elementType, loc);
@@ -481,12 +511,10 @@ void genOmpAccAtomicCapture(Fortran::lower::AbstractConverter &converter,
const Fortran::parser::AssignmentStmt &stmt1 =
std::get<typename AtomicT::Stmt1>(atomicCapture.t).v.statement;
- const Fortran::evaluate::Assignment &assign1 = *stmt1.typedAssignment->v;
const auto &stmt1Var{std::get<Fortran::parser::Variable>(stmt1.t)};
const auto &stmt1Expr{std::get<Fortran::parser::Expr>(stmt1.t)};
const Fortran::parser::AssignmentStmt &stmt2 =
std::get<typename AtomicT::Stmt2>(atomicCapture.t).v.statement;
- const Fortran::evaluate::Assignment &assign2 = *stmt2.typedAssignment->v;
const auto &stmt2Var{std::get<Fortran::parser::Variable>(stmt2.t)};
const auto &stmt2Expr{std::get<Fortran::parser::Expr>(stmt2.t)};
@@ -498,25 +526,37 @@ void genOmpAccAtomicCapture(Fortran::lower::AbstractConverter &converter,
mlir::Value stmt1LHSArg, stmt1RHSArg, stmt2LHSArg, stmt2RHSArg;
mlir::Type elementType;
// LHS evaluations are common to all combinations of `atomic.capture`
- stmt1LHSArg = fir::getBase(converter.genExprAddr(assign1.lhs, stmtCtx));
- stmt2LHSArg = fir::getBase(converter.genExprAddr(assign2.lhs, stmtCtx));
+ stmt1LHSArg = fir::getBase(
+ converter.genExprAddr(*Fortran::semantics::GetExpr(stmt1Var), stmtCtx));
+ stmt2LHSArg = fir::getBase(
+ converter.genExprAddr(*Fortran::semantics::GetExpr(stmt2Var), stmtCtx));
// Operation specific RHS evaluations
if (checkForSingleVariableOnRHS(stmt1)) {
// Atomic capture construct is of the form [capture-stmt, update-stmt] or
// of the form [capture-stmt, write-stmt]
- stmt1RHSArg = fir::getBase(converter.genExprAddr(assign1.rhs, stmtCtx));
+ stmt1RHSArg = fir::getBase(converter.genExprAddr(
+ *Fortran::semantics::GetExpr(stmt1Expr), stmtCtx));
+ // To handle type convert for atomic write/update.
+ const Fortran::evaluate::Assignment &assign2 = *stmt2.typedAssignment->v;
stmt2RHSArg = fir::getBase(converter.genExprValue(assign2.rhs, stmtCtx));
} else {
// Atomic capture construct is of the form [update-stmt, capture-stmt]
+ // To handle type convert for atomic update.
+ const Fortran::evaluate::Assignment &assign1 = *stmt1.typedAssignment->v;
stmt1RHSArg = fir::getBase(converter.genExprValue(assign1.rhs, stmtCtx));
- stmt2RHSArg = fir::getBase(converter.genExprAddr(assign2.lhs, stmtCtx));
+ stmt2RHSArg = fir::getBase(converter.genExprAddr(
+ *Fortran::semantics::GetExpr(stmt2Expr), stmtCtx));
}
// Type information used in generation of `atomic.update` operation
mlir::Type stmt1VarType =
- fir::getBase(converter.genExprValue(assign1.lhs, stmtCtx)).getType();
+ fir::getBase(converter.genExprValue(
+ *Fortran::semantics::GetExpr(stmt1Var), stmtCtx))
+ .getType();
mlir::Type stmt2VarType =
- fir::getBase(converter.genExprValue(assign2.lhs, stmtCtx)).getType();
+ fir::getBase(converter.genExprValue(
+ *Fortran::semantics::GetExpr(stmt2Var), stmtCtx))
+ .getType();
mlir::Operation *atomicCaptureOp = nullptr;
if constexpr (std::is_same<AtomicListT,
@@ -547,7 +587,7 @@ void genOmpAccAtomicCapture(Fortran::lower::AbstractConverter &converter,
genOmpAccAtomicCaptureStatement<AtomicListT>(
converter, stmt1RHSArg, stmt1LHSArg,
/*leftHandClauseList=*/nullptr,
- /*rightHandClauseList=*/nullptr, elementType, loc);
+ /*rightHandClauseList=*/nullptr, elementType, loc, atomicCaptureOp);
genOmpAccAtomicUpdateStatement<AtomicListT>(
converter, stmt1RHSArg, stmt2VarType, stmt2Var, stmt2Expr,
/*leftHandClauseList=*/nullptr,
@@ -560,7 +600,7 @@ void genOmpAccAtomicCapture(Fortran::lower::AbstractConverter &converter,
genOmpAccAtomicCaptureStatement<AtomicListT>(
converter, stmt1RHSArg, stmt1LHSArg,
/*leftHandClauseList=*/nullptr,
- /*rightHandClauseList=*/nullptr, elementType, loc);
+ /*rightHandClauseList=*/nullptr, elementType, loc, atomicCaptureOp);
genOmpAccAtomicWriteStatement<AtomicListT>(
converter, stmt1RHSArg, stmt2RHSArg,
/*leftHandClauseList=*/nullptr,
@@ -575,7 +615,7 @@ void genOmpAccAtomicCapture(Fortran::lower::AbstractConverter &converter,
genOmpAccAtomicCaptureStatement<AtomicListT>(
converter, stmt1LHSArg, stmt2LHSArg,
/*leftHandClauseList=*/nullptr,
- /*rightHandClauseList=*/nullptr, elementType, loc);
+ /*rightHandClauseList=*/nullptr, elementType, loc, atomicCaptureOp);
firOpBuilder.setInsertionPointToStart(&block);
genOmpAccAtomicUpdateStatement<AtomicListT>(
converter, stmt1LHSArg, stmt1VarType, stmt1Var, stmt1Expr,
diff --git a/flang/test/Lower/OpenACC/acc-atomic-read.f90 b/flang/test/Lower/OpenACC/acc-atomic-read.f90
index c1a97a9e5f74f..5c59c86236d4a 100644
--- a/flang/test/Lower/OpenACC/acc-atomic-read.f90
+++ b/flang/test/Lower/OpenACC/acc-atomic-read.f90
@@ -55,5 +55,9 @@ subroutine atomic_read_with_convert()
! CHECK: %[[X_DECL:.*]]:2 = hlfir.declare %[[X]] {uniq_name = "_QFatomic_read_with_convertEx"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
! CHECK: %[[Y:.*]] = fir.alloca i64 {bindc_name = "y", uniq_name = "_QFatomic_read_with_convertEy"}
! CHECK: %[[Y_DECL:.*]]:2 = hlfir.declare %[[Y]] {uniq_name = "_QFatomic_read_with_convertEy"} : (!fir.ref<i64>) -> (!fir.ref<i64>, !fir.ref<i64>)
-! CHECK: %[[CONV:.*]] = fir.convert %[[X_DECL]]#1 : (!fir.ref<i32>) -> !fir.ref<i64>
-! CHECK: acc.atomic.read %[[Y_DECL]]#1 = %[[CONV]] : !fir.ref<i64>, i32
+! CHECK: %[[TEMP:.*]] = fir.alloca i32
+! CHECK: %[[TEMP_DECL:.*]]:2 = hlfir.declare %[[TEMP]] {uniq_name = ".atomic.read.temp"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
+! CHECK: acc.atomic.read %[[TEMP_DECL]]#0 = %1#1 : !fir.ref<i32>, i32
+! CHECK: %[[TEMP_LD:.*]] = fir.load %[[TEMP_DECL]]#0 : !fir.ref<i32>
+! CHECK: %[[TEMP_CVT:.*]] = fir.convert %[[TEMP_LD]] : (i32) -> i64
+! CHECK: fir.store %[[TEMP_CVT]] to %[[Y_DECL]]#1 : !fir.ref<i64>
diff --git a/flang/test/Lower/OpenMP/atomic-capture.f90 b/flang/test/Lower/OpenMP/atomic-capture.f90
index 32d8cd7bbf328..6489a560b77b0 100644
--- a/flang/test/Lower/OpenMP/atomic-capture.f90
+++ b/flang/test/Lower/OpenMP/atomic-capture.f90
@@ -97,3 +97,33 @@ subroutine pointers_in_atomic_capture()
b = a
!$omp end atomic
end subroutine
+
+! CHECK-LABEL: func.func @_QPcapture_with_convert() {
+! CHECK: %[[VAL_0:.*]] = fir.alloca f32 {bindc_name = "c", uniq_name = "_QFcapture_with_convertEc"}
+! CHECK: %[[VAL_1:.*]]:2 = hlfir.declare %[[VAL_0]] {uniq_name = "_QFcapture_with_convertEc"} : (!fir.ref<f32>) -> (!fir.ref<f32>, !fir.ref<f32>)
+! CHECK: %[[VAL_2:.*]] = fir.alloca f64 {bindc_name = "c2", uniq_name = "_QFcapture_with_convertEc2"}
+! CHECK: %[[VAL_3:.*]]:2 = hlfir.declare %[[VAL_2]] {uniq_name = "_QFcapture_with_convertEc2"} : (!fir.ref<f64>) -> (!fir.ref<f64>, !fir.ref<f64>)
+! CHECK: %[[VAL_4:.*]] = fir.alloca f32
+! CHECK: %[[VAL_5:.*]]:2 = hlfir.declare %[[VAL_4]] {uniq_name = ".atomic.read.temp"} : (!fir.ref<f32>) -> (!fir.ref<f32>, !fir.ref<f32>)
+! CHECK: %[[VAL_6:.*]] = arith.constant 2.000000e+00 : f32
+! CHECK: omp.atomic.capture {
+! CHECK: omp.atomic.read %[[VAL_5]]#0 = %[[VAL_1]]#1 : !fir.ref<f32>, f32
+! CHECK: omp.atomic.update %[[VAL_1]]#1 : !fir.ref<f32> {
+! CHECK: ^bb0(%[[VAL_7:.*]]: f32):
+! CHECK: %[[VAL_8:.*]] = arith.mulf %[[VAL_6]], %[[VAL_7]] fastmath<contract> : f32
+! CHECK: omp.yield(%[[VAL_8]] : f32)
+! CHECK: }
+! CHECK: }
+! CHECK: %[[VAL_9:.*]] = fir.load %[[VAL_5]]#0 : !fir.ref<f32>
+! CHECK: %[[VAL_10:.*]] = fir.convert %[[VAL_9]] : (f32) -> f64
+! CHECK: fir.store %[[VAL_10]] to %[[VAL_3]]#1 : !fir.ref<f64>
+! CHECK: return
+! CHECK: }
+subroutine capture_with_convert()
+ real :: c
+ double precision :: c2
+!$omp atomic capture
+ c2 = c
+ c = 2.0 * c
+!$omp end atomic
+end
diff --git a/flang/test/Lower/OpenMP/atomic-read.f90 b/flang/test/Lower/OpenMP/atomic-read.f90
index 8c3f37c94975e..940c0d61d91ca 100644
--- a/flang/test/Lower/OpenMP/atomic-read.f90
+++ b/flang/test/Lower/OpenMP/atomic-read.f90
@@ -89,3 +89,40 @@ subroutine atomic_read_pointer()
x = y
end
+! CHECK-LABEL: func.func @_QPread_with_convert() {
+! CHECK: %[[VAL_0:.*]] = fir.alloca f32 {bindc_name = "a", uniq_name = "_QFread_with_convertEa"}
+! CHECK: %[[VAL_1:.*]]:2 = hlfir.declare %[[VAL_0]] {uniq_name = "_QFread_with_convertEa"} : (!fir.ref<f32>) -> (!fir.ref<f32>, !fir.ref<f32>)
+! CHECK: %[[VAL_2:.*]] = fir.alloca i32 {bindc_name = "b", uniq_name = "_QFread_with_convertEb"}
+! CHECK: %[[VAL_3:.*]]:2 = hlfir.declare %[[VAL_2]] {uniq_name = "_QFread_with_convertEb"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
+! CHECK: %[[VAL_4:.*]] = fir.alloca i32
+! CHECK: %[[VAL_5:.*]]:2 = hlfir.declare %[[VAL_4]] {uniq_name = ".atomic.read.temp"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
+! CHECK: omp.atomic.read %[[VAL_5]]#0 = %[[VAL_3]]#1 : !fir.ref<i32>, i32
+! CHECK: %[[VAL_6:.*]] = fir.load %[[VAL_5]]#0 : !fir.ref<i32>
+! CHECK: %[[VAL_7:.*]] = fir.convert %[[VAL_6]] : (i32) -> f32
+! CHECK: fir.store %[[VAL_7]] to %[[VAL_1]]#1 : !fir.ref<f32>
+subroutine read_with_convert()
+ real :: a
+ integer :: b
+ !$omp atomic read
+ a = b
+end
+
+! CHECK-LABEL: func.func @_QPread_complex_with_convert() {
+! CHECK: %[[VAL_0:.*]] = fir.alloca f32 {bindc_name = "s_v_r2", uniq_name = "_QFread_complex_with_convertEs_v_r2"}
+! CHECK: %[[VAL_1:.*]]:2 = hlfir.declare %[[VAL_0]] {uniq_name = "_QFread_complex_with_convertEs_v_r2"} : (!fir.ref<f32>) -> (!fir.ref<f32>, !fir.ref<f32>)
+! CHECK: %[[VAL_2:.*]] = fir.alloca !fir.complex<4> {bindc_name = "s_x_c2", uniq_name = "_QFread_complex_with_convertEs_x_c2"}
+! CHECK: %[[VAL_3:.*]]:2 = hlfir.declare %[[VAL_2]] {uniq_name = "_QFread_complex_with_convertEs_x_c2"} : (!fir.ref<!fir.complex<4>>) -> (!fir.ref<!fir.complex<4>>, !fir.ref<!fir.complex<4>>)
+! CHECK: %[[VAL_4:.*]] = fir.alloca !fir.complex<4>
+! CHECK: %[[VAL_5:.*]]:2 = hlfir.declare %[[VAL_4]] {uniq_name = ".atomic.read.temp"} : (!fir.ref<!fir.complex<4>>) -> (!fir.ref<!fir.complex<4>>, !fir.ref<!fir.complex<4>>)
+! CHECK: omp.atomic.read %[[VAL_5]]#0 = %[[VAL_3]]#1 : !fir.ref<!fir.complex<4>>, !fir.complex<4>
+! CHECK: %[[VAL_6:.*]] = fir.load %[[VAL_5]]#0 : !fir.ref<!fir.complex<4>>
+! CHECK: %[[VAL_7:.*]] = fir.extract_value %[[VAL_6]], [0 : index] : (!fir.complex<4>) -> f32
+! CHECK: %[[VAL_8:.*]] = fir.convert %[[VAL_7]] : (f32) -> f32
+! CHECK: fir.store %[[VAL_8]] to %[[VAL_1]]#1 : !fir.ref<f32>
+subroutine read_complex_with_convert()
+ real(kind=4) :: s_v_r2
+ complex(kind=4) :: s_x_c2
+ !$omp atomic read
+ s_v_r2 = s_x_c2
+ !$omp end atomic
+end
``````````
</details>
https://github.com/llvm/llvm-project/pull/93776
More information about the flang-commits
mailing list