[flang-commits] [flang] [flang][OpenMP] Add implicit casts for omp.atomic.capture (PR #138163)
via flang-commits
flang-commits at lists.llvm.org
Thu May 1 09:24:19 PDT 2025
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-flang-fir-hlfir
Author: None (NimishMishra)
<details>
<summary>Changes</summary>
This patch adds support for emitting implicit casts for atomic capture if its constituent operations have different yet compatible types.
Fixes: https://github.com/llvm/llvm-project/issues/138123
---
Patch is 21.22 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/138163.diff
4 Files Affected:
- (modified) flang/docs/OpenMPSupport.md (+1-1)
- (modified) flang/lib/Lower/OpenMP/OpenMP.cpp (+161-75)
- (removed) flang/test/Lower/OpenMP/Todo/atomic-capture-implicit-cast.f90 (-48)
- (modified) flang/test/Lower/OpenMP/atomic-implicit-cast.f90 (+78)
``````````diff
diff --git a/flang/docs/OpenMPSupport.md b/flang/docs/OpenMPSupport.md
index 2d4b9dd737777..46be14f4c168c 100644
--- a/flang/docs/OpenMPSupport.md
+++ b/flang/docs/OpenMPSupport.md
@@ -64,4 +64,4 @@ Note : No distinction is made between the support in Parser/Semantics, MLIR, Low
| target teams distribute parallel loop simd construct | P | device, reduction, dist_schedule and linear clauses are not supported |
## OpenMP 3.1, OpenMP 2.5, OpenMP 1.1
-All features except a few corner cases in atomic (complex type, different but compatible types in lhs and rhs), threadprivate (character type) constructs/clauses are supported.
+All features except a few corner cases in threadprivate (character type) constructs/clauses are supported.
diff --git a/flang/lib/Lower/OpenMP/OpenMP.cpp b/flang/lib/Lower/OpenMP/OpenMP.cpp
index 47e7c266ff7d3..526148855b113 100644
--- a/flang/lib/Lower/OpenMP/OpenMP.cpp
+++ b/flang/lib/Lower/OpenMP/OpenMP.cpp
@@ -2865,6 +2865,85 @@ static void genAtomicWrite(lower::AbstractConverter &converter,
rightHandClauseList, loc);
}
+/*
+ Emit an implicit cast. Different yet compatible types on
+ omp.atomic.read constitute valid Fortran. The OMPIRBuilder will
+ emit atomic instructions (on primitive types) and `__atomic_load`
+ libcall (on complex type) without explicitly converting
+ between such compatible types. The OMPIRBuilder relies on the
+ frontend to resolve such inconsistencies between `omp.atomic.read `
+ operand types. Similar inconsistencies between operand types in
+ `omp.atomic.write` are resolved through implicit casting by use of typed
+ assignment (i.e. `evaluate::Assignment`). However, use of typed
+ assignment in `omp.atomic.read` (of form `v = x`) leads to an unsafe,
+ non-atomic load of `x` into a temporary `alloca`, followed by an atomic
+ read of form `v = alloca`. Hence, it is needed to perform a custom
+ implicit cast.
+
+ An atomic read of form `v = x` would (without implicit casting)
+ lower to `omp.atomic.read %v = %x : !fir.ref<type1>, !fir.ref<type2>,
+ type2`. This implicit casting will rather generate the following FIR:
+
+ %alloca = fir.alloca type2
+ omp.atomic.read %alloca = %x : !fir.ref<type2>, !fir.ref<type2>, type2
+ %load = fir.load %alloca : !fir.ref<type2>
+ %cvt = fir.convert %load : (type2) -> type1
+ fir.store %cvt to %v : !fir.ref<type1>
+
+ This sequence of operations is thread-safe since each thread allocates
+ the `alloca` in its stack, and performs `%alloca = %x` atomically. Once
+ safely read, each thread performs the implicit cast on the local
+ `alloca`, and writes the final result to `%v`.
+
+/// \param builder : FirOpBuilder
+/// \param loc : Location for FIR generation
+/// \param toAddress : Address of %v
+/// \param toType : Type of %v
+/// \param fromType : Type of %x
+/// \param alloca : Thread scoped `alloca`
+// It is the responsibility of the caller
+// to position the `alloca` at `AllocaIP`
+// through `builder.getAllocaBlock()`
+*/
+
+static void emitAtomicReadImplicitCast(fir::FirOpBuilder &builder,
+ mlir::Location loc,
+ mlir::Value toAddress, mlir::Type toType,
+ mlir::Type fromType,
+ mlir::Value alloca) {
+ auto load = builder.create<fir::LoadOp>(loc, alloca);
+ if (fir::isa_complex(fromType) && !fir::isa_complex(toType)) {
+ // Emit an additional `ExtractValueOp` if `fromAddress` is of complex
+ // type, but `toAddress` is not.
+ auto extract = builder.create<fir::ExtractValueOp>(
+ loc, mlir::cast<mlir::ComplexType>(fromType).getElementType(), load,
+ builder.getArrayAttr(
+ builder.getIntegerAttr(builder.getIndexType(), 0)));
+ auto cvt = builder.create<fir::ConvertOp>(loc, toType, extract);
+ builder.create<fir::StoreOp>(loc, cvt, toAddress);
+ } else if (!fir::isa_complex(fromType) && fir::isa_complex(toType)) {
+ // Emit an additional `InsertValueOp` if `toAddress` is of complex
+ // type, but `fromAddress` is not.
+ mlir::Value undef = builder.create<fir::UndefOp>(loc, toType);
+ mlir::Type complexEleTy =
+ mlir::cast<mlir::ComplexType>(toType).getElementType();
+ mlir::Value cvt = builder.create<fir::ConvertOp>(loc, complexEleTy, load);
+ mlir::Value zero = builder.createRealZeroConstant(loc, complexEleTy);
+ mlir::Value idx0 = builder.create<fir::InsertValueOp>(
+ loc, toType, undef, cvt,
+ builder.getArrayAttr(
+ builder.getIntegerAttr(builder.getIndexType(), 0)));
+ mlir::Value idx1 = builder.create<fir::InsertValueOp>(
+ loc, toType, idx0, zero,
+ builder.getArrayAttr(
+ builder.getIntegerAttr(builder.getIndexType(), 1)));
+ builder.create<fir::StoreOp>(loc, idx1, toAddress);
+ } else {
+ auto cvt = builder.create<fir::ConvertOp>(loc, toType, load);
+ builder.create<fir::StoreOp>(loc, cvt, toAddress);
+ }
+}
+
/// Processes an atomic construct with read clause.
static void genAtomicRead(lower::AbstractConverter &converter,
const parser::OmpAtomicRead &atomicRead,
@@ -2891,34 +2970,7 @@ static void genAtomicRead(lower::AbstractConverter &converter,
*semantics::GetExpr(assignmentStmtVariable), stmtCtx));
if (fromAddress.getType() != toAddress.getType()) {
- // Emit an implicit cast. Different yet compatible types on
- // omp.atomic.read constitute valid Fortran. The OMPIRBuilder will
- // emit atomic instructions (on primitive types) and `__atomic_load`
- // libcall (on complex type) without explicitly converting
- // between such compatible types. The OMPIRBuilder relies on the
- // frontend to resolve such inconsistencies between `omp.atomic.read `
- // operand types. Similar inconsistencies between operand types in
- // `omp.atomic.write` are resolved through implicit casting by use of typed
- // assignment (i.e. `evaluate::Assignment`). However, use of typed
- // assignment in `omp.atomic.read` (of form `v = x`) leads to an unsafe,
- // non-atomic load of `x` into a temporary `alloca`, followed by an atomic
- // read of form `v = alloca`. Hence, it is needed to perform a custom
- // implicit cast.
-
- // An atomic read of form `v = x` would (without implicit casting)
- // lower to `omp.atomic.read %v = %x : !fir.ref<type1>, !fir.ref<type2>,
- // type2`. This implicit casting will rather generate the following FIR:
- //
- // %alloca = fir.alloca type2
- // omp.atomic.read %alloca = %x : !fir.ref<type2>, !fir.ref<type2>, type2
- // %load = fir.load %alloca : !fir.ref<type2>
- // %cvt = fir.convert %load : (type2) -> type1
- // fir.store %cvt to %v : !fir.ref<type1>
-
- // These sequence of operations is thread-safe since each thread allocates
- // the `alloca` in its stack, and performs `%alloca = %x` atomically. Once
- // safely read, each thread performs the implicit cast on the local
- // `alloca`, and writes the final result to `%v`.
+
mlir::Type toType = fir::unwrapRefType(toAddress.getType());
mlir::Type fromType = fir::unwrapRefType(fromAddress.getType());
fir::FirOpBuilder &builder = converter.getFirOpBuilder();
@@ -2930,37 +2982,8 @@ static void genAtomicRead(lower::AbstractConverter &converter,
genAtomicCaptureStatement(converter, fromAddress, alloca,
leftHandClauseList, rightHandClauseList,
elementType, loc);
- auto load = builder.create<fir::LoadOp>(loc, alloca);
- if (fir::isa_complex(fromType) && !fir::isa_complex(toType)) {
- // Emit an additional `ExtractValueOp` if `fromAddress` is of complex
- // type, but `toAddress` is not.
- auto extract = builder.create<fir::ExtractValueOp>(
- loc, mlir::cast<mlir::ComplexType>(fromType).getElementType(), load,
- builder.getArrayAttr(
- builder.getIntegerAttr(builder.getIndexType(), 0)));
- auto cvt = builder.create<fir::ConvertOp>(loc, toType, extract);
- builder.create<fir::StoreOp>(loc, cvt, toAddress);
- } else if (!fir::isa_complex(fromType) && fir::isa_complex(toType)) {
- // Emit an additional `InsertValueOp` if `toAddress` is of complex
- // type, but `fromAddress` is not.
- mlir::Value undef = builder.create<fir::UndefOp>(loc, toType);
- mlir::Type complexEleTy =
- mlir::cast<mlir::ComplexType>(toType).getElementType();
- mlir::Value cvt = builder.create<fir::ConvertOp>(loc, complexEleTy, load);
- mlir::Value zero = builder.createRealZeroConstant(loc, complexEleTy);
- mlir::Value idx0 = builder.create<fir::InsertValueOp>(
- loc, toType, undef, cvt,
- builder.getArrayAttr(
- builder.getIntegerAttr(builder.getIndexType(), 0)));
- mlir::Value idx1 = builder.create<fir::InsertValueOp>(
- loc, toType, idx0, zero,
- builder.getArrayAttr(
- builder.getIntegerAttr(builder.getIndexType(), 1)));
- builder.create<fir::StoreOp>(loc, idx1, toAddress);
- } else {
- auto cvt = builder.create<fir::ConvertOp>(loc, toType, load);
- builder.create<fir::StoreOp>(loc, cvt, toAddress);
- }
+ emitAtomicReadImplicitCast(builder, loc, toAddress, toType, fromType,
+ alloca);
} else
genAtomicCaptureStatement(converter, fromAddress, toAddress,
leftHandClauseList, rightHandClauseList,
@@ -3049,10 +3072,6 @@ static void genAtomicCapture(lower::AbstractConverter &converter,
mlir::Type stmt2VarType =
fir::getBase(converter.genExprValue(assign2.lhs, stmtCtx)).getType();
- // Check if implicit type is needed
- if (stmt1VarType != stmt2VarType)
- TODO(loc, "atomic capture requiring implicit type casts");
-
mlir::Operation *atomicCaptureOp = nullptr;
mlir::IntegerAttr hint = nullptr;
mlir::omp::ClauseMemoryOrderKindAttr memoryOrder = nullptr;
@@ -3075,10 +3094,31 @@ static void genAtomicCapture(lower::AbstractConverter &converter,
// Atomic capture construct is of the form [capture-stmt, update-stmt]
const semantics::SomeExpr &fromExpr = *semantics::GetExpr(stmt1Expr);
mlir::Type elementType = converter.genType(fromExpr);
- genAtomicCaptureStatement(converter, stmt2LHSArg, stmt1LHSArg,
- /*leftHandClauseList=*/nullptr,
- /*rightHandClauseList=*/nullptr, elementType,
- loc);
+ if (stmt1VarType != stmt2VarType) {
+ mlir::Value alloca;
+ mlir::Type toType = fir::unwrapRefType(stmt1LHSArg.getType());
+ mlir::Type fromType = fir::unwrapRefType(stmt2LHSArg.getType());
+ {
+ mlir::OpBuilder::InsertionGuard guard(firOpBuilder);
+ firOpBuilder.setInsertionPointToStart(firOpBuilder.getAllocaBlock());
+ alloca = firOpBuilder.create<fir::AllocaOp>(loc, fromType);
+ }
+ genAtomicCaptureStatement(converter, stmt2LHSArg, alloca,
+ /*leftHandClauseList=*/nullptr,
+ /*rightHandClauseList=*/nullptr, elementType,
+ loc);
+ {
+ mlir::OpBuilder::InsertionGuard guard(firOpBuilder);
+ firOpBuilder.setInsertionPointAfter(atomicCaptureOp);
+ emitAtomicReadImplicitCast(firOpBuilder, loc, stmt1LHSArg, toType,
+ fromType, alloca);
+ }
+ } else {
+ genAtomicCaptureStatement(converter, stmt2LHSArg, stmt1LHSArg,
+ /*leftHandClauseList=*/nullptr,
+ /*rightHandClauseList=*/nullptr, elementType,
+ loc);
+ }
genAtomicUpdateStatement(
converter, stmt2LHSArg, stmt2VarType, stmt2Var, stmt2Expr,
/*leftHandClauseList=*/nullptr,
@@ -3091,10 +3131,32 @@ static void genAtomicCapture(lower::AbstractConverter &converter,
firOpBuilder.setInsertionPointToStart(&block);
const semantics::SomeExpr &fromExpr = *semantics::GetExpr(stmt1Expr);
mlir::Type elementType = converter.genType(fromExpr);
- genAtomicCaptureStatement(converter, stmt2LHSArg, stmt1LHSArg,
- /*leftHandClauseList=*/nullptr,
- /*rightHandClauseList=*/nullptr, elementType,
- loc);
+
+ if (stmt1VarType != stmt2VarType) {
+ mlir::Value alloca;
+ mlir::Type toType = fir::unwrapRefType(stmt1LHSArg.getType());
+ mlir::Type fromType = fir::unwrapRefType(stmt2LHSArg.getType());
+ {
+ mlir::OpBuilder::InsertionGuard guard(firOpBuilder);
+ firOpBuilder.setInsertionPointToStart(firOpBuilder.getAllocaBlock());
+ alloca = firOpBuilder.create<fir::AllocaOp>(loc, fromType);
+ }
+ genAtomicCaptureStatement(converter, stmt2LHSArg, alloca,
+ /*leftHandClauseList=*/nullptr,
+ /*rightHandClauseList=*/nullptr, elementType,
+ loc);
+ {
+ mlir::OpBuilder::InsertionGuard guard(firOpBuilder);
+ firOpBuilder.setInsertionPointAfter(atomicCaptureOp);
+ emitAtomicReadImplicitCast(firOpBuilder, loc, stmt1LHSArg, toType,
+ fromType, alloca);
+ }
+ } else {
+ genAtomicCaptureStatement(converter, stmt2LHSArg, stmt1LHSArg,
+ /*leftHandClauseList=*/nullptr,
+ /*rightHandClauseList=*/nullptr, elementType,
+ loc);
+ }
genAtomicWriteStatement(converter, stmt2LHSArg, stmt2RHSArg,
/*leftHandClauseList=*/nullptr,
/*rightHandClauseList=*/nullptr, loc);
@@ -3107,10 +3169,34 @@ static void genAtomicCapture(lower::AbstractConverter &converter,
converter, stmt1LHSArg, stmt1VarType, stmt1Var, stmt1Expr,
/*leftHandClauseList=*/nullptr,
/*rightHandClauseList=*/nullptr, loc, atomicCaptureOp);
- genAtomicCaptureStatement(converter, stmt1LHSArg, stmt2LHSArg,
- /*leftHandClauseList=*/nullptr,
- /*rightHandClauseList=*/nullptr, elementType,
- loc);
+
+ if (stmt1VarType != stmt2VarType) {
+ mlir::Value alloca;
+ mlir::Type toType = fir::unwrapRefType(stmt2LHSArg.getType());
+ mlir::Type fromType = fir::unwrapRefType(stmt1LHSArg.getType());
+
+ {
+ mlir::OpBuilder::InsertionGuard guard(firOpBuilder);
+ firOpBuilder.setInsertionPointToStart(firOpBuilder.getAllocaBlock());
+ alloca = firOpBuilder.create<fir::AllocaOp>(loc, fromType);
+ }
+
+ genAtomicCaptureStatement(converter, stmt1LHSArg, alloca,
+ /*leftHandClauseList=*/nullptr,
+ /*rightHandClauseList=*/nullptr, elementType,
+ loc);
+ {
+ mlir::OpBuilder::InsertionGuard guard(firOpBuilder);
+ firOpBuilder.setInsertionPointAfter(atomicCaptureOp);
+ emitAtomicReadImplicitCast(firOpBuilder, loc, stmt2LHSArg, toType,
+ fromType, alloca);
+ }
+ } else {
+ genAtomicCaptureStatement(converter, stmt1LHSArg, stmt2LHSArg,
+ /*leftHandClauseList=*/nullptr,
+ /*rightHandClauseList=*/nullptr, elementType,
+ loc);
+ }
}
firOpBuilder.setInsertionPointToEnd(&block);
firOpBuilder.create<mlir::omp::TerminatorOp>(loc);
diff --git a/flang/test/Lower/OpenMP/Todo/atomic-capture-implicit-cast.f90 b/flang/test/Lower/OpenMP/Todo/atomic-capture-implicit-cast.f90
deleted file mode 100644
index 5b61f1169308f..0000000000000
--- a/flang/test/Lower/OpenMP/Todo/atomic-capture-implicit-cast.f90
+++ /dev/null
@@ -1,48 +0,0 @@
-!RUN: %not_todo_cmd %flang_fc1 -emit-hlfir -fopenmp -o - %s 2>&1 | FileCheck %s
-
-!CHECK: not yet implemented: atomic capture requiring implicit type casts
-subroutine capture_with_convert_f32_to_i32()
- implicit none
- integer :: k, v, i
-
- k = 1
- v = 0
-
- !$omp atomic capture
- v = k
- k = (i + 1) * 3.14
- !$omp end atomic
-end subroutine
-
-subroutine capture_with_convert_i32_to_f64()
- real(8) :: x
- integer :: v
- x = 1.0
- v = 0
- !$omp atomic capture
- v = x
- x = v
- !$omp end atomic
-end subroutine capture_with_convert_i32_to_f64
-
-subroutine capture_with_convert_f64_to_i32()
- integer :: x
- real(8) :: v
- x = 1
- v = 0
- !$omp atomic capture
- x = v
- v = x
- !$omp end atomic
-end subroutine capture_with_convert_f64_to_i32
-
-subroutine capture_with_convert_i32_to_f32()
- real(4) :: x
- integer :: v
- x = 1.0
- v = 0
- !$omp atomic capture
- v = x
- x = x + v
- !$omp end atomic
-end subroutine capture_with_convert_i32_to_f32
diff --git a/flang/test/Lower/OpenMP/atomic-implicit-cast.f90 b/flang/test/Lower/OpenMP/atomic-implicit-cast.f90
index 75f1cbfc979b9..4c1be1ca91ac0 100644
--- a/flang/test/Lower/OpenMP/atomic-implicit-cast.f90
+++ b/flang/test/Lower/OpenMP/atomic-implicit-cast.f90
@@ -4,6 +4,10 @@
! CHECK: func.func @_QPatomic_implicit_cast_read() {
subroutine atomic_implicit_cast_read
+! CHECK: %[[ALLOCA7:.*]] = fir.alloca complex<f64>
+! CHECK: %[[ALLOCA6:.*]] = fir.alloca i32
+! CHECK: %[[ALLOCA5:.*]] = fir.alloca i32
+! CHECK: %[[ALLOCA4:.*]] = fir.alloca i32
! CHECK: %[[ALLOCA3:.*]] = fir.alloca complex<f32>
! CHECK: %[[ALLOCA2:.*]] = fir.alloca complex<f32>
! CHECK: %[[ALLOCA1:.*]] = fir.alloca i32
@@ -53,4 +57,78 @@ subroutine atomic_implicit_cast_read
! CHECK: fir.store %[[CVT]] to %[[M_DECL]]#0 : !fir.ref<complex<f64>>
!$omp atomic read
m = w
+
+! CHECK: %[[CONST:.*]] = arith.constant 1 : i32
+! CHECK: omp.atomic.capture {
+! CHECK: omp.atomic.read %[[ALLOCA4]] = %[[X_DECL]]#0 : !fir.ref<i32>, !fir.ref<i32>, i32
+! CHECK: omp.atomic.update %[[X_DECL]]#0 : !fir.ref<i32> {
+! CHECK: ^bb0(%[[ARG:.*]]: i32):
+! CHECK: %[[RESULT:.*]] = arith.addi %[[ARG]], %[[CONST]] : i32
+! CHECK: omp.yield(%[[RESULT]] : i32)
+! CHECK: }
+! CHECK: }
+! CHECK: %[[LOAD:.*]] = fir.load %[[ALLOCA4]] : !fir.ref<i32>
+! CHECK: %[[CVT:.*]] = fir.convert %[[LOAD]] : (i32) -> f32
+! CHECK: fir.store %[[CVT]] to %[[Y_DECL]]#0 : !fir.ref<f32>
+ !$omp atomic capture
+ y = x
+ x = x + 1
+ !$omp end atomic
+
+! CHECK: %[[CONST:.*]] = arith.constant 10 : i32
+! CHECK: omp.atomic.capture {
+! CHECK: omp.atomic.read %[[ALLOCA5:.*]] = %[[X_DECL]]#0 : !fir.ref<i32>, !fir.ref<i32>, i32
+! CHECK: omp.atomic.write %[[X_DECL]]#0 = %[[CONST]] : !fir.ref<i32>, i32
+! CHECK: }
+! CHECK: %[[LOAD:.*]] = fir.load %[[ALLOCA5]] : !fir.ref<i32>
+! CHECK: %[[CVT:.*]] = fir.convert %[[LOAD]] : (i32) -> f64
+! CHECK: fir.store %[[CVT]] to %[[Z_DECL]]#0 : !fir.ref<f64>
+ !$omp atomic capture
+ z = x
+ x = 10
+ !$omp end atomic
+
+! CHECK: %[[CONST:.*]] = arith.constant 1 : i32
+! CHECK: omp.atomic.capture {
+! CHECK: omp.atomic.update %[[X_DECL]]#0 : !fir.ref<i32> {
+! CHECK: ^bb0(%[[ARG:.*]]: i32):
+! CHECK: %[[RESULT:.*]] = arith.addi %[[ARG]], %[[CONST]] : i32
+! CHECK: omp.yield(%[[RESULT]] : i32)
+! CHECK: }
+! CHECK: omp.atomic.read %[[ALLOCA6]] = %[[X_DECL]]#0 : !fir.ref<i32>, !fir.ref<i32>, i32
+! CHECK: %[[LOAD:.*]] = fir.load %[[ALLOCA6]] : !fir.ref<i32>
+! CHECK: %[[UNDEF:.*]] = fir.undefined complex<f32>
+! CHECK: %[[CVT:.*]] = fir.convert %[[LOAD]] : (i32) -> f32
+! CHECK: %[[CST:.*]] = arith.constant 0.000000e+00 : f32
+! CHECK: %[[IDX1:.*]] = fir.insert_value %[[UNDEF]], %[[CVT]], [0 : index] : (complex<f32>, f32) -> complex<f32>
+! CHECK: %[[IDX2:.*]] = fir.insert_value %[[IDX1]], %[[CST]], [1 : index] : (complex<f32>, f32) -> complex<f32>
+! CHECK: fir.store %[[IDX2]] to %[[W_DECL]]#0 : !fir.ref<complex<f32>>
+ !$om...
[truncated]
``````````
</details>
https://github.com/llvm/llvm-project/pull/138163
More information about the flang-commits
mailing list