[flang-commits] [flang] [flang] Postpone hlfir.end_associate generation for calls. (PR #138786)
Slava Zakharin via flang-commits
flang-commits at lists.llvm.org
Fri May 9 19:37:51 PDT 2025
https://github.com/vzakhari updated https://github.com/llvm/llvm-project/pull/138786
>From f65c8b369ce0e866996095c239293f0716608d11 Mon Sep 17 00:00:00 2001
From: Slava Zakharin <szakharin at nvidia.com>
Date: Tue, 6 May 2025 16:38:48 -0700
Subject: [PATCH 1/5] [flang] Postpone hlfir.end_associate generation for
calls.
If we generate hlfir.end_associate at the end of the statement,
we get HLFIR that is easier to optimize, because there are no
compiler-generated operations with side effects between the call
and its consumers. This allows more hlfir.eval_in_mem operations to
reuse the LHS instead of allocating a temporary buffer.
I do not think the same can always be done for hlfir.copy_out, e.g.:
```
subroutine test2(x)
interface
function array_func2(x,y)
real:: x(*), array_func2(10), y
end function array_func2
end interface
real :: x(:)
x = array_func2(x, 1.0)
end subroutine test2
```
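If we postpone the copy-out until after the assignment, then
the result may be wrong: the copy-out of the temporary created
for the actual argument could overwrite the values just assigned
to x.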
---
flang/lib/Lower/ConvertCall.cpp | 42 +++++++--
.../Lower/HLFIR/call-postponed-associate.f90 | 85 +++++++++++++++++++
2 files changed, 121 insertions(+), 6 deletions(-)
create mode 100644 flang/test/Lower/HLFIR/call-postponed-associate.f90
diff --git a/flang/lib/Lower/ConvertCall.cpp b/flang/lib/Lower/ConvertCall.cpp
index a5b85e25b1af0..d37d51f6ec634 100644
--- a/flang/lib/Lower/ConvertCall.cpp
+++ b/flang/lib/Lower/ConvertCall.cpp
@@ -960,9 +960,26 @@ struct CallCleanUp {
mlir::Value tempVar;
mlir::Value mustFree;
};
- void genCleanUp(mlir::Location loc, fir::FirOpBuilder &builder) {
- Fortran::common::visit([&](auto &c) { c.genCleanUp(loc, builder); },
+
+ /// Generate clean-up code.
+ /// If \p postponeAssociates is true, the ExprAssociate clean-up
+ /// is not generated, and instead the corresponding CallCleanUp
+ /// object is returned as the result.
+ std::optional<CallCleanUp> genCleanUp(mlir::Location loc,
+ fir::FirOpBuilder &builder,
+ bool postponeAssociates) {
+ std::optional<CallCleanUp> postponed;
+ Fortran::common::visit(Fortran::common::visitors{
+ [&](CopyIn &c) { c.genCleanUp(loc, builder); },
+ [&](ExprAssociate &c) {
+ if (postponeAssociates)
+ postponed = CallCleanUp{c};
+ else
+ c.genCleanUp(loc, builder);
+ },
+ },
cleanUp);
+ return postponed;
}
std::variant<CopyIn, ExprAssociate> cleanUp;
};
@@ -1729,10 +1746,23 @@ genUserCall(Fortran::lower::PreparedActualArguments &loweredActuals,
caller, callSiteType, callContext.resultType,
callContext.isElementalProcWithArrayArgs());
- /// Clean-up associations and copy-in.
- for (auto cleanUp : callCleanUps)
- cleanUp.genCleanUp(loc, builder);
-
+ // Clean-up associations and copy-in.
+ // The association clean-ups are postponed to the end of the statement
+ // lowering. The copy-in clean-ups may be delayed as well,
+ // but they are done immediately after the call currently.
+ llvm::SmallVector<CallCleanUp> associateCleanups;
+ for (auto cleanUp : callCleanUps) {
+ auto postponed =
+ cleanUp.genCleanUp(loc, builder, /*postponeAssociates=*/true);
+ if (postponed)
+ associateCleanups.push_back(*postponed);
+ }
+
+ fir::FirOpBuilder *bldr = &builder;
+ callContext.stmtCtx.attachCleanup([=]() {
+ for (auto cleanUp : associateCleanups)
+ (void)cleanUp.genCleanUp(loc, *bldr, /*postponeAssociates=*/false);
+ });
if (auto *entity = std::get_if<hlfir::EntityWithAttributes>(&loweredResult))
return *entity;
diff --git a/flang/test/Lower/HLFIR/call-postponed-associate.f90 b/flang/test/Lower/HLFIR/call-postponed-associate.f90
new file mode 100644
index 0000000000000..18df62b44324b
--- /dev/null
+++ b/flang/test/Lower/HLFIR/call-postponed-associate.f90
@@ -0,0 +1,85 @@
+! RUN: bbc -emit-hlfir -o - %s -I nowhere | FileCheck %s
+
+subroutine test1
+ interface
+ function array_func1(x)
+ real:: x, array_func1(10)
+ end function array_func1
+ end interface
+ real :: x(10)
+ x = array_func1(1.0)
+end subroutine test1
+! CHECK-LABEL: func.func @_QPtest1() {
+! CHECK: %[[VAL_5:.*]] = arith.constant 1.000000e+00 : f32
+! CHECK: %[[VAL_6:.*]]:3 = hlfir.associate %[[VAL_5]] {adapt.valuebyref} : (f32) -> (!fir.ref<f32>, !fir.ref<f32>, i1)
+! CHECK: %[[VAL_17:.*]] = hlfir.eval_in_mem shape %{{.*}} : (!fir.shape<1>) -> !hlfir.expr<10xf32> {
+! CHECK: fir.call @_QParray_func1
+! CHECK: fir.save_result
+! CHECK: }
+! CHECK: hlfir.assign %[[VAL_17]] to %{{.*}} : !hlfir.expr<10xf32>, !fir.ref<!fir.array<10xf32>>
+! CHECK: hlfir.end_associate %[[VAL_6]]#1, %[[VAL_6]]#2 : !fir.ref<f32>, i1
+
+subroutine test2(x)
+ interface
+ function array_func2(x,y)
+ real:: x(*), array_func2(10), y
+ end function array_func2
+ end interface
+ real :: x(:)
+ x = array_func2(x, 1.0)
+end subroutine test2
+! CHECK-LABEL: func.func @_QPtest2(
+! CHECK: %[[VAL_3:.*]] = arith.constant 1.000000e+00 : f32
+! CHECK: %[[VAL_4:.*]]:2 = hlfir.copy_in %{{.*}} to %{{.*}} : (!fir.box<!fir.array<?xf32>>, !fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>>) -> (!fir.box<!fir.array<?xf32>>, i1)
+! CHECK: %[[VAL_5:.*]] = fir.box_addr %[[VAL_4]]#0 : (!fir.box<!fir.array<?xf32>>) -> !fir.ref<!fir.array<?xf32>>
+! CHECK: %[[VAL_6:.*]]:3 = hlfir.associate %[[VAL_3]] {adapt.valuebyref} : (f32) -> (!fir.ref<f32>, !fir.ref<f32>, i1)
+! CHECK: %[[VAL_17:.*]] = hlfir.eval_in_mem shape %{{.*}} : (!fir.shape<1>) -> !hlfir.expr<10xf32> {
+! CHECK: ^bb0(%[[VAL_18:.*]]: !fir.ref<!fir.array<10xf32>>):
+! CHECK: %[[VAL_19:.*]] = fir.call @_QParray_func2(%[[VAL_5]], %[[VAL_6]]#0) fastmath<contract> : (!fir.ref<!fir.array<?xf32>>, !fir.ref<f32>) -> !fir.array<10xf32>
+! CHECK: fir.save_result %[[VAL_19]] to %[[VAL_18]](%{{.*}}) : !fir.array<10xf32>, !fir.ref<!fir.array<10xf32>>, !fir.shape<1>
+! CHECK: }
+! CHECK: hlfir.copy_out %{{.*}}, %[[VAL_4]]#1 to %{{.*}} : (!fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>>, i1, !fir.box<!fir.array<?xf32>>) -> ()
+! CHECK: hlfir.assign %[[VAL_17]] to %{{.*}} : !hlfir.expr<10xf32>, !fir.box<!fir.array<?xf32>>
+! CHECK: hlfir.end_associate %[[VAL_6]]#1, %[[VAL_6]]#2 : !fir.ref<f32>, i1
+! CHECK: hlfir.destroy %[[VAL_17]] : !hlfir.expr<10xf32>
+
+subroutine test3(x)
+ interface
+ function array_func3(x)
+ real :: x, array_func3(10)
+ end function array_func3
+ end interface
+ logical :: x
+ if (any(array_func3(1.0).le.array_func3(2.0))) x = .true.
+end subroutine test3
+! CHECK-LABEL: func.func @_QPtest3(
+! CHECK: %[[VAL_2:.*]] = arith.constant 1.000000e+00 : f32
+! CHECK: %[[VAL_3:.*]]:3 = hlfir.associate %[[VAL_2]] {adapt.valuebyref} : (f32) -> (!fir.ref<f32>, !fir.ref<f32>, i1)
+! CHECK: %[[VAL_14:.*]] = hlfir.eval_in_mem shape %{{.*}} : (!fir.shape<1>) -> !hlfir.expr<10xf32> {
+! CHECK: ^bb0(%[[VAL_15:.*]]: !fir.ref<!fir.array<10xf32>>):
+! CHECK: %[[VAL_16:.*]] = fir.call @_QParray_func3(%[[VAL_3]]#0) fastmath<contract> : (!fir.ref<f32>) -> !fir.array<10xf32>
+! CHECK: fir.save_result %[[VAL_16]] to %[[VAL_15]](%{{.*}}) : !fir.array<10xf32>, !fir.ref<!fir.array<10xf32>>, !fir.shape<1>
+! CHECK: }
+! CHECK: %[[VAL_17:.*]] = arith.constant 2.000000e+00 : f32
+! CHECK: %[[VAL_18:.*]]:3 = hlfir.associate %[[VAL_17]] {adapt.valuebyref} : (f32) -> (!fir.ref<f32>, !fir.ref<f32>, i1)
+! CHECK: %[[VAL_29:.*]] = hlfir.eval_in_mem shape %{{.*}} : (!fir.shape<1>) -> !hlfir.expr<10xf32> {
+! CHECK: ^bb0(%[[VAL_30:.*]]: !fir.ref<!fir.array<10xf32>>):
+! CHECK: %[[VAL_31:.*]] = fir.call @_QParray_func3(%[[VAL_18]]#0) fastmath<contract> : (!fir.ref<f32>) -> !fir.array<10xf32>
+! CHECK: fir.save_result %[[VAL_31]] to %[[VAL_30]](%{{.*}}) : !fir.array<10xf32>, !fir.ref<!fir.array<10xf32>>, !fir.shape<1>
+! CHECK: }
+! CHECK: %[[VAL_32:.*]] = hlfir.elemental %{{.*}} unordered : (!fir.shape<1>) -> !hlfir.expr<?x!fir.logical<4>> {
+! CHECK: ^bb0(%[[VAL_33:.*]]: index):
+! CHECK: %[[VAL_34:.*]] = hlfir.apply %[[VAL_14]], %[[VAL_33]] : (!hlfir.expr<10xf32>, index) -> f32
+! CHECK: %[[VAL_35:.*]] = hlfir.apply %[[VAL_29]], %[[VAL_33]] : (!hlfir.expr<10xf32>, index) -> f32
+! CHECK: %[[VAL_36:.*]] = arith.cmpf ole, %[[VAL_34]], %[[VAL_35]] fastmath<contract> : f32
+! CHECK: %[[VAL_37:.*]] = fir.convert %[[VAL_36]] : (i1) -> !fir.logical<4>
+! CHECK: hlfir.yield_element %[[VAL_37]] : !fir.logical<4>
+! CHECK: }
+! CHECK: %[[VAL_38:.*]] = hlfir.any %[[VAL_32]] : (!hlfir.expr<?x!fir.logical<4>>) -> !fir.logical<4>
+! CHECK: hlfir.destroy %[[VAL_32]] : !hlfir.expr<?x!fir.logical<4>>
+! CHECK: hlfir.end_associate %[[VAL_18]]#1, %[[VAL_18]]#2 : !fir.ref<f32>, i1
+! CHECK: hlfir.destroy %[[VAL_29]] : !hlfir.expr<10xf32>
+! CHECK: hlfir.end_associate %[[VAL_3]]#1, %[[VAL_3]]#2 : !fir.ref<f32>, i1
+! CHECK: hlfir.destroy %[[VAL_14]] : !hlfir.expr<10xf32>
+! CHECK: %[[VAL_39:.*]] = fir.convert %[[VAL_38]] : (!fir.logical<4>) -> i1
+! CHECK: fir.if %[[VAL_39]] {
>From 8ddc8be77b16303127470993be7333d35fb9fb56 Mon Sep 17 00:00:00 2001
From: Slava Zakharin <szakharin at nvidia.com>
Date: Tue, 6 May 2025 19:14:39 -0700
Subject: [PATCH 2/5] Added test changes missing from the original patch.
---
flang/test/Lower/HLFIR/entry_return.f90 | 8 ++++----
flang/test/Lower/HLFIR/proc-pointer-comp-nopass.f90 | 2 +-
2 files changed, 5 insertions(+), 5 deletions(-)
diff --git a/flang/test/Lower/HLFIR/entry_return.f90 b/flang/test/Lower/HLFIR/entry_return.f90
index 5d3e160af2df6..18fb2b571b950 100644
--- a/flang/test/Lower/HLFIR/entry_return.f90
+++ b/flang/test/Lower/HLFIR/entry_return.f90
@@ -51,13 +51,13 @@ logical function f2()
! CHECK: %[[VAL_6:.*]]:3 = hlfir.associate %[[VAL_4]] {adapt.valuebyref} : (f32) -> (!fir.ref<f32>, !fir.ref<f32>, i1)
! CHECK: %[[VAL_7:.*]]:3 = hlfir.associate %[[VAL_5]] {adapt.valuebyref} : (f32) -> (!fir.ref<f32>, !fir.ref<f32>, i1)
! CHECK: %[[VAL_8:.*]] = fir.call @_QPcomplex(%[[VAL_6]]#0, %[[VAL_7]]#0) fastmath<contract> : (!fir.ref<f32>, !fir.ref<f32>) -> f32
-! CHECK: hlfir.end_associate %[[VAL_6]]#1, %[[VAL_6]]#2 : !fir.ref<f32>, i1
-! CHECK: hlfir.end_associate %[[VAL_7]]#1, %[[VAL_7]]#2 : !fir.ref<f32>, i1
! CHECK: %[[VAL_9:.*]] = arith.constant 0.000000e+00 : f32
! CHECK: %[[VAL_10:.*]] = fir.undefined complex<f32>
! CHECK: %[[VAL_11:.*]] = fir.insert_value %[[VAL_10]], %[[VAL_8]], [0 : index] : (complex<f32>, f32) -> complex<f32>
! CHECK: %[[VAL_12:.*]] = fir.insert_value %[[VAL_11]], %[[VAL_9]], [1 : index] : (complex<f32>, f32) -> complex<f32>
! CHECK: hlfir.assign %[[VAL_12]] to %[[VAL_1]]#0 : complex<f32>, !fir.ref<complex<f32>>
+! CHECK: hlfir.end_associate %[[VAL_6]]#1, %[[VAL_6]]#2 : !fir.ref<f32>, i1
+! CHECK: hlfir.end_associate %[[VAL_7]]#1, %[[VAL_7]]#2 : !fir.ref<f32>, i1
! CHECK: %[[VAL_13:.*]] = fir.load %[[VAL_3]]#0 : !fir.ref<!fir.logical<4>>
! CHECK: return %[[VAL_13]] : !fir.logical<4>
! CHECK: }
@@ -74,13 +74,13 @@ logical function f2()
! CHECK: %[[VAL_6:.*]]:3 = hlfir.associate %[[VAL_4]] {adapt.valuebyref} : (f32) -> (!fir.ref<f32>, !fir.ref<f32>, i1)
! CHECK: %[[VAL_7:.*]]:3 = hlfir.associate %[[VAL_5]] {adapt.valuebyref} : (f32) -> (!fir.ref<f32>, !fir.ref<f32>, i1)
! CHECK: %[[VAL_8:.*]] = fir.call @_QPcomplex(%[[VAL_6]]#0, %[[VAL_7]]#0) fastmath<contract> : (!fir.ref<f32>, !fir.ref<f32>) -> f32
-! CHECK: hlfir.end_associate %[[VAL_6]]#1, %[[VAL_6]]#2 : !fir.ref<f32>, i1
-! CHECK: hlfir.end_associate %[[VAL_7]]#1, %[[VAL_7]]#2 : !fir.ref<f32>, i1
! CHECK: %[[VAL_9:.*]] = arith.constant 0.000000e+00 : f32
! CHECK: %[[VAL_10:.*]] = fir.undefined complex<f32>
! CHECK: %[[VAL_11:.*]] = fir.insert_value %[[VAL_10]], %[[VAL_8]], [0 : index] : (complex<f32>, f32) -> complex<f32>
! CHECK: %[[VAL_12:.*]] = fir.insert_value %[[VAL_11]], %[[VAL_9]], [1 : index] : (complex<f32>, f32) -> complex<f32>
! CHECK: hlfir.assign %[[VAL_12]] to %[[VAL_1]]#0 : complex<f32>, !fir.ref<complex<f32>>
+! CHECK: hlfir.end_associate %[[VAL_6]]#1, %[[VAL_6]]#2 : !fir.ref<f32>, i1
+! CHECK: hlfir.end_associate %[[VAL_7]]#1, %[[VAL_7]]#2 : !fir.ref<f32>, i1
! CHECK: %[[VAL_13:.*]] = fir.load %[[VAL_1]]#0 : !fir.ref<complex<f32>>
! CHECK: return %[[VAL_13]] : complex<f32>
! CHECK: }
diff --git a/flang/test/Lower/HLFIR/proc-pointer-comp-nopass.f90 b/flang/test/Lower/HLFIR/proc-pointer-comp-nopass.f90
index 28659a33d0893..206b6e4e9b797 100644
--- a/flang/test/Lower/HLFIR/proc-pointer-comp-nopass.f90
+++ b/flang/test/Lower/HLFIR/proc-pointer-comp-nopass.f90
@@ -32,8 +32,8 @@ real function test1(x)
! CHECK: %[[VAL_7:.*]] = fir.load %[[VAL_6]] : !fir.ref<!fir.boxproc<(!fir.ref<f32>) -> f32>>
! CHECK: %[[VAL_8:.*]] = fir.box_addr %[[VAL_7]] : (!fir.boxproc<(!fir.ref<f32>) -> f32>) -> ((!fir.ref<f32>) -> f32)
! CHECK: %[[VAL_9:.*]] = fir.call %[[VAL_8]](%[[VAL_5]]#0) fastmath<contract> : (!fir.ref<f32>) -> f32
-! CHECK: hlfir.end_associate %[[VAL_5]]#1, %[[VAL_5]]#2 : !fir.ref<f32>, i1
! CHECK: hlfir.assign %[[VAL_9]] to %[[VAL_2]]#0 : f32, !fir.ref<f32>
+! CHECK: hlfir.end_associate %[[VAL_5]]#1, %[[VAL_5]]#2 : !fir.ref<f32>, i1
subroutine test2(x)
use proc_comp_defs, only : t, iface
>From d843f4eb45ad474adef13371540b2d22817074f0 Mon Sep 17 00:00:00 2001
From: Slava Zakharin <szakharin at nvidia.com>
Date: Thu, 8 May 2025 12:18:07 -0700
Subject: [PATCH 3/5] Fixed clean-ups insertion for atomic capture.
---
flang/lib/Lower/OpenMP/OpenMP.cpp | 4 +++-
flang/test/Lower/OpenMP/atomic-capture.f90 | 20 ++++++++++++++++++++
flang/test/Lower/OpenMP/atomic-update.f90 | 21 +++++++++++++++++++++
3 files changed, 44 insertions(+), 1 deletion(-)
diff --git a/flang/lib/Lower/OpenMP/OpenMP.cpp b/flang/lib/Lower/OpenMP/OpenMP.cpp
index fcd3de9671098..d1a77a2624628 100644
--- a/flang/lib/Lower/OpenMP/OpenMP.cpp
+++ b/flang/lib/Lower/OpenMP/OpenMP.cpp
@@ -3129,7 +3129,9 @@ static void genAtomicCapture(lower::AbstractConverter &converter,
}
firOpBuilder.setInsertionPointToEnd(&block);
firOpBuilder.create<mlir::omp::TerminatorOp>(loc);
- firOpBuilder.setInsertionPointToStart(&block);
+ // The clean-ups associated with the statements inside the capture
+ // construct must be generated after the AtomicCaptureOp.
+ firOpBuilder.setInsertionPointAfter(atomicCaptureOp);
}
//===----------------------------------------------------------------------===//
diff --git a/flang/test/Lower/OpenMP/atomic-capture.f90 b/flang/test/Lower/OpenMP/atomic-capture.f90
index bbb08220af9d9..b5c8edc8f31c1 100644
--- a/flang/test/Lower/OpenMP/atomic-capture.f90
+++ b/flang/test/Lower/OpenMP/atomic-capture.f90
@@ -97,3 +97,23 @@ subroutine pointers_in_atomic_capture()
b = a
!$omp end atomic
end subroutine
+
+! Check that the clean-ups associated with the function call
+! are generated after the omp.atomic.capture operation:
+! CHECK-LABEL: func.func @_QPfunc_call_cleanup(
+subroutine func_call_cleanup(x, v, vv)
+ integer :: x, v, vv
+
+! CHECK: %[[VAL_7:.*]]:3 = hlfir.associate %{{.*}} {adapt.valuebyref} : (i32) -> (!fir.ref<i32>, !fir.ref<i32>, i1)
+! CHECK: %[[VAL_8:.*]] = fir.call @_QPfunc(%[[VAL_7]]#0) fastmath<contract> : (!fir.ref<i32>) -> f32
+! CHECK: %[[VAL_9:.*]] = fir.convert %[[VAL_8]] : (f32) -> i32
+! CHECK: omp.atomic.capture {
+! CHECK: omp.atomic.read %{{.*}} = %[[VAL_3:.*]]#0 : !fir.ref<i32>, !fir.ref<i32>, i32
+! CHECK: omp.atomic.write %[[VAL_3]]#0 = %[[VAL_9]] : !fir.ref<i32>, i32
+! CHECK: }
+! CHECK: hlfir.end_associate %[[VAL_7]]#1, %[[VAL_7]]#2 : !fir.ref<i32>, i1
+ !$omp atomic capture
+ v = x
+ x = func(vv + 1)
+ !$omp end atomic
+end subroutine func_call_cleanup
diff --git a/flang/test/Lower/OpenMP/atomic-update.f90 b/flang/test/Lower/OpenMP/atomic-update.f90
index 31bf447006930..e0269ea1f8af1 100644
--- a/flang/test/Lower/OpenMP/atomic-update.f90
+++ b/flang/test/Lower/OpenMP/atomic-update.f90
@@ -201,3 +201,24 @@ program OmpAtomicUpdate
!$omp atomic update
x = x + sum([ (y+2, y=1, z) ])
end program OmpAtomicUpdate
+
+! Check that the clean-ups associated with the function call
+! are generated after the omp.atomic.update operation:
+! CHECK-LABEL: func.func @_QPfunc_call_cleanup(
+subroutine func_call_cleanup(v, vv)
+ integer v, vv
+
+! CHECK: %[[VAL_6:.*]]:3 = hlfir.associate %{{.*}} {adapt.valuebyref} : (i32) -> (!fir.ref<i32>, !fir.ref<i32>, i1)
+! CHECK: %[[VAL_7:.*]] = fir.call @_QPfunc(%[[VAL_6]]#0) fastmath<contract> : (!fir.ref<i32>) -> f32
+! CHECK: omp.atomic.update %{{.*}} : !fir.ref<i32> {
+! CHECK: ^bb0(%[[VAL_8:.*]]: i32):
+! CHECK: %[[VAL_9:.*]] = fir.convert %[[VAL_8]] : (i32) -> f32
+! CHECK: %[[VAL_10:.*]] = arith.addf %[[VAL_9]], %[[VAL_7]] fastmath<contract> : f32
+! CHECK: %[[VAL_11:.*]] = fir.convert %[[VAL_10]] : (f32) -> i32
+! CHECK: omp.yield(%[[VAL_11]] : i32)
+! CHECK: }
+! CHECK: hlfir.end_associate %[[VAL_6]]#1, %[[VAL_6]]#2 : !fir.ref<i32>, i1
+ !$omp atomic update
+ v = v + func(vv + 1)
+ !$omp end atomic
+end subroutine func_call_cleanup
>From 22c381e753d137e02ef81996c2ec65ca91f86410 Mon Sep 17 00:00:00 2001
From: Slava Zakharin <szakharin at nvidia.com>
Date: Thu, 8 May 2025 16:13:01 -0700
Subject: [PATCH 4/5] Fixed atomic capture cases with atomic update inside.
---
flang/lib/Lower/OpenMP/OpenMP.cpp | 20 +++++++---
flang/test/Lower/OpenMP/atomic-capture.f90 | 44 ++++++++++++++++++++--
2 files changed, 55 insertions(+), 9 deletions(-)
diff --git a/flang/lib/Lower/OpenMP/OpenMP.cpp b/flang/lib/Lower/OpenMP/OpenMP.cpp
index d1a77a2624628..5b0b54b9e0377 100644
--- a/flang/lib/Lower/OpenMP/OpenMP.cpp
+++ b/flang/lib/Lower/OpenMP/OpenMP.cpp
@@ -2729,7 +2729,8 @@ static void genAtomicUpdateStatement(
const parser::Expr &assignmentStmtExpr,
const parser::OmpAtomicClauseList *leftHandClauseList,
const parser::OmpAtomicClauseList *rightHandClauseList, mlir::Location loc,
- mlir::Operation *atomicCaptureOp = nullptr) {
+ mlir::Operation *atomicCaptureOp = nullptr,
+ lower::StatementContext *atomicCaptureStmtCtx = nullptr) {
// Generate `atomic.update` operation for atomic assignment statements
fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
mlir::Location currentLocation = converter.getCurrentLocation();
@@ -2803,15 +2804,24 @@ static void genAtomicUpdateStatement(
},
assignmentStmtExpr.u);
lower::StatementContext nonAtomicStmtCtx;
+ lower::StatementContext *stmtCtxPtr = &nonAtomicStmtCtx;
if (!nonAtomicSubExprs.empty()) {
// Generate non atomic part before all the atomic operations.
auto insertionPoint = firOpBuilder.saveInsertionPoint();
- if (atomicCaptureOp)
+ if (atomicCaptureOp) {
+ assert(atomicCaptureStmtCtx && "must specify statement context");
firOpBuilder.setInsertionPoint(atomicCaptureOp);
+ // Any clean-ups associated with the expression lowering
+ // must also be generated outside of the atomic update operation
+ // and after the atomic capture operation.
+ // The atomicCaptureStmtCtx will be finalized at the end
+ // of the atomic capture operation generation.
+ stmtCtxPtr = atomicCaptureStmtCtx;
+ }
mlir::Value nonAtomicVal;
for (auto *nonAtomicSubExpr : nonAtomicSubExprs) {
nonAtomicVal = fir::getBase(converter.genExprValue(
- currentLocation, *nonAtomicSubExpr, nonAtomicStmtCtx));
+ currentLocation, *nonAtomicSubExpr, *stmtCtxPtr));
exprValueOverrides.try_emplace(nonAtomicSubExpr, nonAtomicVal);
}
if (atomicCaptureOp)
@@ -3097,7 +3107,7 @@ static void genAtomicCapture(lower::AbstractConverter &converter,
genAtomicUpdateStatement(
converter, stmt2LHSArg, stmt2VarType, stmt2Var, stmt2Expr,
/*leftHandClauseList=*/nullptr,
- /*rightHandClauseList=*/nullptr, loc, atomicCaptureOp);
+ /*rightHandClauseList=*/nullptr, loc, atomicCaptureOp, &stmtCtx);
} else {
// Atomic capture construct is of the form [capture-stmt, write-stmt]
firOpBuilder.setInsertionPoint(atomicCaptureOp);
@@ -3121,7 +3131,7 @@ static void genAtomicCapture(lower::AbstractConverter &converter,
genAtomicUpdateStatement(
converter, stmt1LHSArg, stmt1VarType, stmt1Var, stmt1Expr,
/*leftHandClauseList=*/nullptr,
- /*rightHandClauseList=*/nullptr, loc, atomicCaptureOp);
+ /*rightHandClauseList=*/nullptr, loc, atomicCaptureOp, &stmtCtx);
genAtomicCaptureStatement(converter, stmt1LHSArg, stmt2LHSArg,
/*leftHandClauseList=*/nullptr,
/*rightHandClauseList=*/nullptr, elementType,
diff --git a/flang/test/Lower/OpenMP/atomic-capture.f90 b/flang/test/Lower/OpenMP/atomic-capture.f90
index b5c8edc8f31c1..2f800d534dc36 100644
--- a/flang/test/Lower/OpenMP/atomic-capture.f90
+++ b/flang/test/Lower/OpenMP/atomic-capture.f90
@@ -102,18 +102,54 @@ subroutine pointers_in_atomic_capture()
! are generated after the omp.atomic.capture operation:
! CHECK-LABEL: func.func @_QPfunc_call_cleanup(
subroutine func_call_cleanup(x, v, vv)
+ interface
+ integer function func(x)
+ integer :: x
+ end function func
+ end interface
integer :: x, v, vv
! CHECK: %[[VAL_7:.*]]:3 = hlfir.associate %{{.*}} {adapt.valuebyref} : (i32) -> (!fir.ref<i32>, !fir.ref<i32>, i1)
-! CHECK: %[[VAL_8:.*]] = fir.call @_QPfunc(%[[VAL_7]]#0) fastmath<contract> : (!fir.ref<i32>) -> f32
-! CHECK: %[[VAL_9:.*]] = fir.convert %[[VAL_8]] : (f32) -> i32
+! CHECK: %[[VAL_8:.*]] = fir.call @_QPfunc(%[[VAL_7]]#0) fastmath<contract> : (!fir.ref<i32>) -> i32
! CHECK: omp.atomic.capture {
-! CHECK: omp.atomic.read %{{.*}} = %[[VAL_3:.*]]#0 : !fir.ref<i32>, !fir.ref<i32>, i32
-! CHECK: omp.atomic.write %[[VAL_3]]#0 = %[[VAL_9]] : !fir.ref<i32>, i32
+! CHECK: omp.atomic.read %[[VAL_1:.*]]#0 = %[[VAL_3:.*]]#0 : !fir.ref<i32>, !fir.ref<i32>, i32
+! CHECK: omp.atomic.write %[[VAL_3]]#0 = %[[VAL_8]] : !fir.ref<i32>, i32
! CHECK: }
! CHECK: hlfir.end_associate %[[VAL_7]]#1, %[[VAL_7]]#2 : !fir.ref<i32>, i1
!$omp atomic capture
v = x
x = func(vv + 1)
!$omp end atomic
+
+! CHECK: %[[VAL_12:.*]]:3 = hlfir.associate %{{.*}} {adapt.valuebyref} : (i32) -> (!fir.ref<i32>, !fir.ref<i32>, i1)
+! CHECK: %[[VAL_13:.*]] = fir.call @_QPfunc(%[[VAL_12]]#0) fastmath<contract> : (!fir.ref<i32>) -> i32
+! CHECK: omp.atomic.capture {
+! CHECK: omp.atomic.read %[[VAL_1]]#0 = %[[VAL_3]]#0 : !fir.ref<i32>, !fir.ref<i32>, i32
+! CHECK: omp.atomic.update %[[VAL_3]]#0 : !fir.ref<i32> {
+! CHECK: ^bb0(%[[VAL_14:.*]]: i32):
+! CHECK: %[[VAL_15:.*]] = arith.addi %[[VAL_13]], %[[VAL_14]] : i32
+! CHECK: omp.yield(%[[VAL_15]] : i32)
+! CHECK: }
+! CHECK: }
+! CHECK: hlfir.end_associate %[[VAL_12]]#1, %[[VAL_12]]#2 : !fir.ref<i32>, i1
+ !$omp atomic capture
+ v = x
+ x = func(vv + 1) + x
+ !$omp end atomic
+
+! CHECK: %[[VAL_19:.*]]:3 = hlfir.associate %{{.*}} {adapt.valuebyref} : (i32) -> (!fir.ref<i32>, !fir.ref<i32>, i1)
+! CHECK: %[[VAL_20:.*]] = fir.call @_QPfunc(%[[VAL_19]]#0) fastmath<contract> : (!fir.ref<i32>) -> i32
+! CHECK: omp.atomic.capture {
+! CHECK: omp.atomic.update %[[VAL_3]]#0 : !fir.ref<i32> {
+! CHECK: ^bb0(%[[VAL_21:.*]]: i32):
+! CHECK: %[[VAL_22:.*]] = arith.addi %[[VAL_20]], %[[VAL_21]] : i32
+! CHECK: omp.yield(%[[VAL_22]] : i32)
+! CHECK: }
+! CHECK: omp.atomic.read %[[VAL_1]]#0 = %[[VAL_3]]#0 : !fir.ref<i32>, !fir.ref<i32>, i32
+! CHECK: }
+! CHECK: hlfir.end_associate %[[VAL_19]]#1, %[[VAL_19]]#2 : !fir.ref<i32>, i1
+ !$omp atomic capture
+ x = func(vv + 1) + x
+ v = x
+ !$omp end atomic
end subroutine func_call_cleanup
>From 94df928e0031c9d26cad29841701959b7322508b Mon Sep 17 00:00:00 2001
From: Slava Zakharin <szakharin at nvidia.com>
Date: Fri, 9 May 2025 19:33:34 -0700
Subject: [PATCH 5/5] Fixed atomic handling for OpenACC.
---
flang/lib/Lower/OpenACC.cpp | 24 ++++++--
.../test/Lower/OpenACC/acc-atomic-capture.f90 | 57 +++++++++++++++++++
.../test/Lower/OpenACC/acc-atomic-update.f90 | 18 +++++-
3 files changed, 92 insertions(+), 7 deletions(-)
diff --git a/flang/lib/Lower/OpenACC.cpp b/flang/lib/Lower/OpenACC.cpp
index 82daa05c165cb..c4529a3115996 100644
--- a/flang/lib/Lower/OpenACC.cpp
+++ b/flang/lib/Lower/OpenACC.cpp
@@ -414,7 +414,8 @@ static inline void genAtomicUpdateStatement(
Fortran::lower::AbstractConverter &converter, mlir::Value lhsAddr,
mlir::Type varType, const Fortran::parser::Variable &assignmentStmtVariable,
const Fortran::parser::Expr &assignmentStmtExpr, mlir::Location loc,
- mlir::Operation *atomicCaptureOp = nullptr) {
+ mlir::Operation *atomicCaptureOp = nullptr,
+ Fortran::lower::StatementContext *atomicCaptureStmtCtx = nullptr) {
// Generate `atomic.update` operation for atomic assignment statements
fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
mlir::Location currentLocation = converter.getCurrentLocation();
@@ -494,15 +495,24 @@ static inline void genAtomicUpdateStatement(
},
assignmentStmtExpr.u);
Fortran::lower::StatementContext nonAtomicStmtCtx;
+ Fortran::lower::StatementContext *stmtCtxPtr = &nonAtomicStmtCtx;
if (!nonAtomicSubExprs.empty()) {
// Generate non atomic part before all the atomic operations.
auto insertionPoint = firOpBuilder.saveInsertionPoint();
- if (atomicCaptureOp)
+ if (atomicCaptureOp) {
+ assert(atomicCaptureStmtCtx && "must specify statement context");
firOpBuilder.setInsertionPoint(atomicCaptureOp);
+ // Any clean-ups associated with the expression lowering
+ // must also be generated outside of the atomic update operation
+ // and after the atomic capture operation.
+ // The atomicCaptureStmtCtx will be finalized at the end
+ // of the atomic capture operation generation.
+ stmtCtxPtr = atomicCaptureStmtCtx;
+ }
mlir::Value nonAtomicVal;
for (auto *nonAtomicSubExpr : nonAtomicSubExprs) {
nonAtomicVal = fir::getBase(converter.genExprValue(
- currentLocation, *nonAtomicSubExpr, nonAtomicStmtCtx));
+ currentLocation, *nonAtomicSubExpr, *stmtCtxPtr));
exprValueOverrides.try_emplace(nonAtomicSubExpr, nonAtomicVal);
}
if (atomicCaptureOp)
@@ -650,7 +660,7 @@ void genAtomicCapture(Fortran::lower::AbstractConverter &converter,
genAtomicCaptureStatement(converter, stmt2LHSArg, stmt1LHSArg,
elementType, loc);
genAtomicUpdateStatement(converter, stmt2LHSArg, stmt2VarType, stmt2Var,
- stmt2Expr, loc, atomicCaptureOp);
+ stmt2Expr, loc, atomicCaptureOp, &stmtCtx);
} else {
// Atomic capture construct is of the form [capture-stmt, write-stmt]
firOpBuilder.setInsertionPoint(atomicCaptureOp);
@@ -670,13 +680,15 @@ void genAtomicCapture(Fortran::lower::AbstractConverter &converter,
*Fortran::semantics::GetExpr(stmt2Expr);
mlir::Type elementType = converter.genType(fromExpr);
genAtomicUpdateStatement(converter, stmt1LHSArg, stmt1VarType, stmt1Var,
- stmt1Expr, loc, atomicCaptureOp);
+ stmt1Expr, loc, atomicCaptureOp, &stmtCtx);
genAtomicCaptureStatement(converter, stmt1LHSArg, stmt2LHSArg, elementType,
loc);
}
firOpBuilder.setInsertionPointToEnd(&block);
firOpBuilder.create<mlir::acc::TerminatorOp>(loc);
- firOpBuilder.setInsertionPointToStart(&block);
+ // The clean-ups associated with the statements inside the capture
+ // construct must be generated after the AtomicCaptureOp.
+ firOpBuilder.setInsertionPointAfter(atomicCaptureOp);
}
template <typename Op>
diff --git a/flang/test/Lower/OpenACC/acc-atomic-capture.f90 b/flang/test/Lower/OpenACC/acc-atomic-capture.f90
index 82059908bcd0b..ee38ab6ce826a 100644
--- a/flang/test/Lower/OpenACC/acc-atomic-capture.f90
+++ b/flang/test/Lower/OpenACC/acc-atomic-capture.f90
@@ -306,3 +306,60 @@ end subroutine comp_ref_in_atomic_capture2
! CHECK: }
! CHECK: acc.atomic.read %[[V_DECL]]#0 = %[[C]] : !fir.ref<i32>, !fir.ref<i32>, i32
! CHECK: }
+
+! CHECK-LABEL: func.func @_QPatomic_capture_with_associate() {
+subroutine atomic_capture_with_associate
+ interface
+ integer function func(x)
+ integer :: x
+ end function func
+ end interface
+! CHECK: %[[X_DECL:.*]]:2 = hlfir.declare %{{.*}} {uniq_name = "_QFatomic_capture_with_associateEx"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
+! CHECK: %[[Y_DECL:.*]]:2 = hlfir.declare %{{.*}} {uniq_name = "_QFatomic_capture_with_associateEy"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
+! CHECK: %[[Z_DECL:.*]]:2 = hlfir.declare %{{.*}} {uniq_name = "_QFatomic_capture_with_associateEz"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
+ integer :: x, y, z
+
+! CHECK: %[[VAL_10:.*]]:3 = hlfir.associate %{{.*}} {adapt.valuebyref} : (i32) -> (!fir.ref<i32>, !fir.ref<i32>, i1)
+! CHECK: %[[VAL_11:.*]] = fir.call @_QPfunc(%[[VAL_10]]#0) fastmath<contract> : (!fir.ref<i32>) -> i32
+! CHECK: acc.atomic.capture {
+! CHECK: acc.atomic.read %[[X_DECL]]#0 = %[[Y_DECL]]#0 : !fir.ref<i32>, !fir.ref<i32>, i32
+! CHECK: acc.atomic.write %[[Y_DECL]]#0 = %[[VAL_11]] : !fir.ref<i32>, i32
+! CHECK: }
+! CHECK: hlfir.end_associate %[[VAL_10]]#1, %[[VAL_10]]#2 : !fir.ref<i32>, i1
+ !$acc atomic capture
+ x = y
+ y = func(z + 1)
+ !$acc end atomic
+
+! CHECK: %[[VAL_15:.*]]:3 = hlfir.associate %{{.*}} {adapt.valuebyref} : (i32) -> (!fir.ref<i32>, !fir.ref<i32>, i1)
+! CHECK: %[[VAL_16:.*]] = fir.call @_QPfunc(%[[VAL_15]]#0) fastmath<contract> : (!fir.ref<i32>) -> i32
+! CHECK: acc.atomic.capture {
+! CHECK: acc.atomic.update %[[Y_DECL]]#0 : !fir.ref<i32> {
+! CHECK: ^bb0(%[[VAL_17:.*]]: i32):
+! CHECK: %[[VAL_18:.*]] = arith.muli %[[VAL_16]], %[[VAL_17]] : i32
+! CHECK: acc.yield %[[VAL_18]] : i32
+! CHECK: }
+! CHECK: acc.atomic.read %[[X_DECL]]#0 = %[[Y_DECL]]#0 : !fir.ref<i32>, !fir.ref<i32>, i32
+! CHECK: }
+! CHECK: hlfir.end_associate %[[VAL_15]]#1, %[[VAL_15]]#2 : !fir.ref<i32>, i1
+ !$acc atomic capture
+ y = func(z + 1) * y
+ x = y
+ !$acc end atomic
+
+! CHECK: %[[VAL_22:.*]]:3 = hlfir.associate %{{.*}} {adapt.valuebyref} : (i32) -> (!fir.ref<i32>, !fir.ref<i32>, i1)
+! CHECK: %[[VAL_23:.*]] = fir.call @_QPfunc(%[[VAL_22]]#0) fastmath<contract> : (!fir.ref<i32>) -> i32
+! CHECK: acc.atomic.capture {
+! CHECK: acc.atomic.read %[[X_DECL]]#0 = %[[Y_DECL]]#0 : !fir.ref<i32>, !fir.ref<i32>, i32
+! CHECK: acc.atomic.update %[[Y_DECL]]#0 : !fir.ref<i32> {
+! CHECK: ^bb0(%[[VAL_24:.*]]: i32):
+! CHECK: %[[VAL_25:.*]] = arith.addi %[[VAL_23]], %[[VAL_24]] : i32
+! CHECK: acc.yield %[[VAL_25]] : i32
+! CHECK: }
+! CHECK: }
+! CHECK: hlfir.end_associate %[[VAL_22]]#1, %[[VAL_22]]#2 : !fir.ref<i32>, i1
+ !$acc atomic capture
+ x = y
+ y = func(z + 1) + y
+ !$acc end atomic
+end subroutine atomic_capture_with_associate
diff --git a/flang/test/Lower/OpenACC/acc-atomic-update.f90 b/flang/test/Lower/OpenACC/acc-atomic-update.f90
index da2972877244c..71aa69fd64eba 100644
--- a/flang/test/Lower/OpenACC/acc-atomic-update.f90
+++ b/flang/test/Lower/OpenACC/acc-atomic-update.f90
@@ -3,6 +3,11 @@
! RUN: %flang_fc1 -fopenacc -emit-hlfir %s -o - | FileCheck %s
program acc_atomic_update_test
+ interface
+ integer function func(x)
+ integer :: x
+ end function func
+ end interface
integer :: x, y, z
integer, pointer :: a, b
integer, target :: c, d
@@ -67,7 +72,18 @@ program acc_atomic_update_test
!$acc atomic
i1 = i1 + 1
!$acc end atomic
+
+!CHECK: %[[VAL_44:.*]]:3 = hlfir.associate %{{.*}} {adapt.valuebyref} : (i32) -> (!fir.ref<i32>, !fir.ref<i32>, i1)
+!CHECK: %[[VAL_45:.*]] = fir.call @_QPfunc(%[[VAL_44]]#0) fastmath<contract> : (!fir.ref<i32>) -> i32
+!CHECK: acc.atomic.update %[[X_DECL]]#0 : !fir.ref<i32> {
+!CHECK: ^bb0(%[[VAL_46:.*]]: i32):
+!CHECK: %[[VAL_47:.*]] = arith.addi %[[VAL_46]], %[[VAL_45]] : i32
+!CHECK: acc.yield %[[VAL_47]] : i32
+!CHECK: }
+!CHECK: hlfir.end_associate %[[VAL_44]]#1, %[[VAL_44]]#2 : !fir.ref<i32>, i1
+ !$acc atomic update
+ x = x + func(z + 1)
+ !$acc end atomic
!CHECK: return
!CHECK: }
end program acc_atomic_update_test
-
More information about the flang-commits
mailing list