[flang-commits] [flang] [flang] Postpone hlfir.end_associate generation for calls. (PR #138786)

Slava Zakharin via flang-commits flang-commits at lists.llvm.org
Fri May 9 19:37:51 PDT 2025


https://github.com/vzakhari updated https://github.com/llvm/llvm-project/pull/138786

>From f65c8b369ce0e866996095c239293f0716608d11 Mon Sep 17 00:00:00 2001
From: Slava Zakharin <szakharin at nvidia.com>
Date: Tue, 6 May 2025 16:38:48 -0700
Subject: [PATCH 1/5] [flang] Postpone hlfir.end_associate generation for
 calls.

If we generate hlfir.end_associate at the end of the statement,
we get more easily optimizable HLFIR, because there are no
compiler-generated operations with side effects between the call
and its consumers. This allows more hlfir.eval_in_mem operations
to reuse the LHS instead of allocating a temporary buffer.

I do not think the same can always be done for hlfir.copy_out, e.g.:
```
subroutine test2(x)
  interface
     function array_func2(x,y)
       real:: x(*), array_func2(10), y
     end function array_func2
  end interface
  real :: x(:)
  x = array_func2(x, 1.0)
end subroutine test2
```

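If we postpone the copy-out until after the assignment, then
the result may be wrong: the copy-out would write the contents of
the copy-in temporary back into x, overwriting the values that the
assignment has just stored into x.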
---
 flang/lib/Lower/ConvertCall.cpp               | 42 +++++++--
 .../Lower/HLFIR/call-postponed-associate.f90  | 85 +++++++++++++++++++
 2 files changed, 121 insertions(+), 6 deletions(-)
 create mode 100644 flang/test/Lower/HLFIR/call-postponed-associate.f90

diff --git a/flang/lib/Lower/ConvertCall.cpp b/flang/lib/Lower/ConvertCall.cpp
index a5b85e25b1af0..d37d51f6ec634 100644
--- a/flang/lib/Lower/ConvertCall.cpp
+++ b/flang/lib/Lower/ConvertCall.cpp
@@ -960,9 +960,26 @@ struct CallCleanUp {
     mlir::Value tempVar;
     mlir::Value mustFree;
   };
-  void genCleanUp(mlir::Location loc, fir::FirOpBuilder &builder) {
-    Fortran::common::visit([&](auto &c) { c.genCleanUp(loc, builder); },
+
+  /// Generate clean-up code.
+  /// If \p postponeAssociates is true, the ExprAssociate clean-up
+  /// is not generated, and instead the corresponding CallCleanUp
+  /// object is returned as the result.
+  std::optional<CallCleanUp> genCleanUp(mlir::Location loc,
+                                        fir::FirOpBuilder &builder,
+                                        bool postponeAssociates) {
+    std::optional<CallCleanUp> postponed;
+    Fortran::common::visit(Fortran::common::visitors{
+                               [&](CopyIn &c) { c.genCleanUp(loc, builder); },
+                               [&](ExprAssociate &c) {
+                                 if (postponeAssociates)
+                                   postponed = CallCleanUp{c};
+                                 else
+                                   c.genCleanUp(loc, builder);
+                               },
+                           },
                            cleanUp);
+    return postponed;
   }
   std::variant<CopyIn, ExprAssociate> cleanUp;
 };
@@ -1729,10 +1746,23 @@ genUserCall(Fortran::lower::PreparedActualArguments &loweredActuals,
       caller, callSiteType, callContext.resultType,
       callContext.isElementalProcWithArrayArgs());
 
-  /// Clean-up associations and copy-in.
-  for (auto cleanUp : callCleanUps)
-    cleanUp.genCleanUp(loc, builder);
-
+  // Clean-up associations and copy-in.
+  // The association clean-ups are postponed to the end of the statement
+  // lowering. The copy-in clean-ups may be delayed as well,
+  // but they are done immediately after the call currently.
+  llvm::SmallVector<CallCleanUp> associateCleanups;
+  for (auto cleanUp : callCleanUps) {
+    auto postponed =
+        cleanUp.genCleanUp(loc, builder, /*postponeAssociates=*/true);
+    if (postponed)
+      associateCleanups.push_back(*postponed);
+  }
+
+  fir::FirOpBuilder *bldr = &builder;
+  callContext.stmtCtx.attachCleanup([=]() {
+    for (auto cleanUp : associateCleanups)
+      (void)cleanUp.genCleanUp(loc, *bldr, /*postponeAssociates=*/false);
+  });
   if (auto *entity = std::get_if<hlfir::EntityWithAttributes>(&loweredResult))
     return *entity;
 
diff --git a/flang/test/Lower/HLFIR/call-postponed-associate.f90 b/flang/test/Lower/HLFIR/call-postponed-associate.f90
new file mode 100644
index 0000000000000..18df62b44324b
--- /dev/null
+++ b/flang/test/Lower/HLFIR/call-postponed-associate.f90
@@ -0,0 +1,85 @@
+! RUN: bbc -emit-hlfir -o - %s -I nowhere | FileCheck %s
+
+subroutine test1
+  interface
+     function array_func1(x)
+       real:: x, array_func1(10)
+     end function array_func1
+  end interface
+  real :: x(10)
+  x = array_func1(1.0)
+end subroutine test1
+! CHECK-LABEL:   func.func @_QPtest1() {
+! CHECK:           %[[VAL_5:.*]] = arith.constant 1.000000e+00 : f32
+! CHECK:           %[[VAL_6:.*]]:3 = hlfir.associate %[[VAL_5]] {adapt.valuebyref} : (f32) -> (!fir.ref<f32>, !fir.ref<f32>, i1)
+! CHECK:           %[[VAL_17:.*]] = hlfir.eval_in_mem shape %{{.*}} : (!fir.shape<1>) -> !hlfir.expr<10xf32> {
+! CHECK:             fir.call @_QParray_func1
+! CHECK:             fir.save_result
+! CHECK:           }
+! CHECK:           hlfir.assign %[[VAL_17]] to %{{.*}} : !hlfir.expr<10xf32>, !fir.ref<!fir.array<10xf32>>
+! CHECK:           hlfir.end_associate %[[VAL_6]]#1, %[[VAL_6]]#2 : !fir.ref<f32>, i1
+
+subroutine test2(x)
+  interface
+     function array_func2(x,y)
+       real:: x(*), array_func2(10), y
+     end function array_func2
+  end interface
+  real :: x(:)
+  x = array_func2(x, 1.0)
+end subroutine test2
+! CHECK-LABEL:   func.func @_QPtest2(
+! CHECK:           %[[VAL_3:.*]] = arith.constant 1.000000e+00 : f32
+! CHECK:           %[[VAL_4:.*]]:2 = hlfir.copy_in %{{.*}} to %{{.*}} : (!fir.box<!fir.array<?xf32>>, !fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>>) -> (!fir.box<!fir.array<?xf32>>, i1)
+! CHECK:           %[[VAL_5:.*]] = fir.box_addr %[[VAL_4]]#0 : (!fir.box<!fir.array<?xf32>>) -> !fir.ref<!fir.array<?xf32>>
+! CHECK:           %[[VAL_6:.*]]:3 = hlfir.associate %[[VAL_3]] {adapt.valuebyref} : (f32) -> (!fir.ref<f32>, !fir.ref<f32>, i1)
+! CHECK:           %[[VAL_17:.*]] = hlfir.eval_in_mem shape %{{.*}} : (!fir.shape<1>) -> !hlfir.expr<10xf32> {
+! CHECK:           ^bb0(%[[VAL_18:.*]]: !fir.ref<!fir.array<10xf32>>):
+! CHECK:             %[[VAL_19:.*]] = fir.call @_QParray_func2(%[[VAL_5]], %[[VAL_6]]#0) fastmath<contract> : (!fir.ref<!fir.array<?xf32>>, !fir.ref<f32>) -> !fir.array<10xf32>
+! CHECK:             fir.save_result %[[VAL_19]] to %[[VAL_18]](%{{.*}}) : !fir.array<10xf32>, !fir.ref<!fir.array<10xf32>>, !fir.shape<1>
+! CHECK:           }
+! CHECK:           hlfir.copy_out %{{.*}}, %[[VAL_4]]#1 to %{{.*}} : (!fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>>, i1, !fir.box<!fir.array<?xf32>>) -> ()
+! CHECK:           hlfir.assign %[[VAL_17]] to %{{.*}} : !hlfir.expr<10xf32>, !fir.box<!fir.array<?xf32>>
+! CHECK:           hlfir.end_associate %[[VAL_6]]#1, %[[VAL_6]]#2 : !fir.ref<f32>, i1
+! CHECK:           hlfir.destroy %[[VAL_17]] : !hlfir.expr<10xf32>
+
+subroutine test3(x)
+  interface
+     function array_func3(x)
+       real :: x, array_func3(10)
+     end function array_func3
+  end interface
+  logical :: x
+  if (any(array_func3(1.0).le.array_func3(2.0))) x = .true.
+end subroutine test3
+! CHECK-LABEL:   func.func @_QPtest3(
+! CHECK:           %[[VAL_2:.*]] = arith.constant 1.000000e+00 : f32
+! CHECK:           %[[VAL_3:.*]]:3 = hlfir.associate %[[VAL_2]] {adapt.valuebyref} : (f32) -> (!fir.ref<f32>, !fir.ref<f32>, i1)
+! CHECK:           %[[VAL_14:.*]] = hlfir.eval_in_mem shape %{{.*}} : (!fir.shape<1>) -> !hlfir.expr<10xf32> {
+! CHECK:           ^bb0(%[[VAL_15:.*]]: !fir.ref<!fir.array<10xf32>>):
+! CHECK:             %[[VAL_16:.*]] = fir.call @_QParray_func3(%[[VAL_3]]#0) fastmath<contract> : (!fir.ref<f32>) -> !fir.array<10xf32>
+! CHECK:             fir.save_result %[[VAL_16]] to %[[VAL_15]](%{{.*}}) : !fir.array<10xf32>, !fir.ref<!fir.array<10xf32>>, !fir.shape<1>
+! CHECK:           }
+! CHECK:           %[[VAL_17:.*]] = arith.constant 2.000000e+00 : f32
+! CHECK:           %[[VAL_18:.*]]:3 = hlfir.associate %[[VAL_17]] {adapt.valuebyref} : (f32) -> (!fir.ref<f32>, !fir.ref<f32>, i1)
+! CHECK:           %[[VAL_29:.*]] = hlfir.eval_in_mem shape %{{.*}} : (!fir.shape<1>) -> !hlfir.expr<10xf32> {
+! CHECK:           ^bb0(%[[VAL_30:.*]]: !fir.ref<!fir.array<10xf32>>):
+! CHECK:             %[[VAL_31:.*]] = fir.call @_QParray_func3(%[[VAL_18]]#0) fastmath<contract> : (!fir.ref<f32>) -> !fir.array<10xf32>
+! CHECK:             fir.save_result %[[VAL_31]] to %[[VAL_30]](%{{.*}}) : !fir.array<10xf32>, !fir.ref<!fir.array<10xf32>>, !fir.shape<1>
+! CHECK:           }
+! CHECK:           %[[VAL_32:.*]] = hlfir.elemental %{{.*}} unordered : (!fir.shape<1>) -> !hlfir.expr<?x!fir.logical<4>> {
+! CHECK:           ^bb0(%[[VAL_33:.*]]: index):
+! CHECK:             %[[VAL_34:.*]] = hlfir.apply %[[VAL_14]], %[[VAL_33]] : (!hlfir.expr<10xf32>, index) -> f32
+! CHECK:             %[[VAL_35:.*]] = hlfir.apply %[[VAL_29]], %[[VAL_33]] : (!hlfir.expr<10xf32>, index) -> f32
+! CHECK:             %[[VAL_36:.*]] = arith.cmpf ole, %[[VAL_34]], %[[VAL_35]] fastmath<contract> : f32
+! CHECK:             %[[VAL_37:.*]] = fir.convert %[[VAL_36]] : (i1) -> !fir.logical<4>
+! CHECK:             hlfir.yield_element %[[VAL_37]] : !fir.logical<4>
+! CHECK:           }
+! CHECK:           %[[VAL_38:.*]] = hlfir.any %[[VAL_32]] : (!hlfir.expr<?x!fir.logical<4>>) -> !fir.logical<4>
+! CHECK:           hlfir.destroy %[[VAL_32]] : !hlfir.expr<?x!fir.logical<4>>
+! CHECK:           hlfir.end_associate %[[VAL_18]]#1, %[[VAL_18]]#2 : !fir.ref<f32>, i1
+! CHECK:           hlfir.destroy %[[VAL_29]] : !hlfir.expr<10xf32>
+! CHECK:           hlfir.end_associate %[[VAL_3]]#1, %[[VAL_3]]#2 : !fir.ref<f32>, i1
+! CHECK:           hlfir.destroy %[[VAL_14]] : !hlfir.expr<10xf32>
+! CHECK:           %[[VAL_39:.*]] = fir.convert %[[VAL_38]] : (!fir.logical<4>) -> i1
+! CHECK:           fir.if %[[VAL_39]] {

>From 8ddc8be77b16303127470993be7333d35fb9fb56 Mon Sep 17 00:00:00 2001
From: Slava Zakharin <szakharin at nvidia.com>
Date: Tue, 6 May 2025 19:14:39 -0700
Subject: [PATCH 2/5] Added test changes missing from the original patch.

---
 flang/test/Lower/HLFIR/entry_return.f90             | 8 ++++----
 flang/test/Lower/HLFIR/proc-pointer-comp-nopass.f90 | 2 +-
 2 files changed, 5 insertions(+), 5 deletions(-)

diff --git a/flang/test/Lower/HLFIR/entry_return.f90 b/flang/test/Lower/HLFIR/entry_return.f90
index 5d3e160af2df6..18fb2b571b950 100644
--- a/flang/test/Lower/HLFIR/entry_return.f90
+++ b/flang/test/Lower/HLFIR/entry_return.f90
@@ -51,13 +51,13 @@ logical function f2()
 ! CHECK:           %[[VAL_6:.*]]:3 = hlfir.associate %[[VAL_4]] {adapt.valuebyref} : (f32) -> (!fir.ref<f32>, !fir.ref<f32>, i1)
 ! CHECK:           %[[VAL_7:.*]]:3 = hlfir.associate %[[VAL_5]] {adapt.valuebyref} : (f32) -> (!fir.ref<f32>, !fir.ref<f32>, i1)
 ! CHECK:           %[[VAL_8:.*]] = fir.call @_QPcomplex(%[[VAL_6]]#0, %[[VAL_7]]#0) fastmath<contract> : (!fir.ref<f32>, !fir.ref<f32>) -> f32
-! CHECK:           hlfir.end_associate %[[VAL_6]]#1, %[[VAL_6]]#2 : !fir.ref<f32>, i1
-! CHECK:           hlfir.end_associate %[[VAL_7]]#1, %[[VAL_7]]#2 : !fir.ref<f32>, i1
 ! CHECK:           %[[VAL_9:.*]] = arith.constant 0.000000e+00 : f32
 ! CHECK:           %[[VAL_10:.*]] = fir.undefined complex<f32>
 ! CHECK:           %[[VAL_11:.*]] = fir.insert_value %[[VAL_10]], %[[VAL_8]], [0 : index] : (complex<f32>, f32) -> complex<f32>
 ! CHECK:           %[[VAL_12:.*]] = fir.insert_value %[[VAL_11]], %[[VAL_9]], [1 : index] : (complex<f32>, f32) -> complex<f32>
 ! CHECK:           hlfir.assign %[[VAL_12]] to %[[VAL_1]]#0 : complex<f32>, !fir.ref<complex<f32>>
+! CHECK:           hlfir.end_associate %[[VAL_6]]#1, %[[VAL_6]]#2 : !fir.ref<f32>, i1
+! CHECK:           hlfir.end_associate %[[VAL_7]]#1, %[[VAL_7]]#2 : !fir.ref<f32>, i1
 ! CHECK:           %[[VAL_13:.*]] = fir.load %[[VAL_3]]#0 : !fir.ref<!fir.logical<4>>
 ! CHECK:           return %[[VAL_13]] : !fir.logical<4>
 ! CHECK:         }
@@ -74,13 +74,13 @@ logical function f2()
 ! CHECK:           %[[VAL_6:.*]]:3 = hlfir.associate %[[VAL_4]] {adapt.valuebyref} : (f32) -> (!fir.ref<f32>, !fir.ref<f32>, i1)
 ! CHECK:           %[[VAL_7:.*]]:3 = hlfir.associate %[[VAL_5]] {adapt.valuebyref} : (f32) -> (!fir.ref<f32>, !fir.ref<f32>, i1)
 ! CHECK:           %[[VAL_8:.*]] = fir.call @_QPcomplex(%[[VAL_6]]#0, %[[VAL_7]]#0) fastmath<contract> : (!fir.ref<f32>, !fir.ref<f32>) -> f32
-! CHECK:           hlfir.end_associate %[[VAL_6]]#1, %[[VAL_6]]#2 : !fir.ref<f32>, i1
-! CHECK:           hlfir.end_associate %[[VAL_7]]#1, %[[VAL_7]]#2 : !fir.ref<f32>, i1
 ! CHECK:           %[[VAL_9:.*]] = arith.constant 0.000000e+00 : f32
 ! CHECK:           %[[VAL_10:.*]] = fir.undefined complex<f32>
 ! CHECK:           %[[VAL_11:.*]] = fir.insert_value %[[VAL_10]], %[[VAL_8]], [0 : index] : (complex<f32>, f32) -> complex<f32>
 ! CHECK:           %[[VAL_12:.*]] = fir.insert_value %[[VAL_11]], %[[VAL_9]], [1 : index] : (complex<f32>, f32) -> complex<f32>
 ! CHECK:           hlfir.assign %[[VAL_12]] to %[[VAL_1]]#0 : complex<f32>, !fir.ref<complex<f32>>
+! CHECK:           hlfir.end_associate %[[VAL_6]]#1, %[[VAL_6]]#2 : !fir.ref<f32>, i1
+! CHECK:           hlfir.end_associate %[[VAL_7]]#1, %[[VAL_7]]#2 : !fir.ref<f32>, i1
 ! CHECK:           %[[VAL_13:.*]] = fir.load %[[VAL_1]]#0 : !fir.ref<complex<f32>>
 ! CHECK:           return %[[VAL_13]] : complex<f32>
 ! CHECK:         }
diff --git a/flang/test/Lower/HLFIR/proc-pointer-comp-nopass.f90 b/flang/test/Lower/HLFIR/proc-pointer-comp-nopass.f90
index 28659a33d0893..206b6e4e9b797 100644
--- a/flang/test/Lower/HLFIR/proc-pointer-comp-nopass.f90
+++ b/flang/test/Lower/HLFIR/proc-pointer-comp-nopass.f90
@@ -32,8 +32,8 @@ real function test1(x)
 ! CHECK:           %[[VAL_7:.*]] = fir.load %[[VAL_6]] : !fir.ref<!fir.boxproc<(!fir.ref<f32>) -> f32>>
 ! CHECK:           %[[VAL_8:.*]] = fir.box_addr %[[VAL_7]] : (!fir.boxproc<(!fir.ref<f32>) -> f32>) -> ((!fir.ref<f32>) -> f32)
 ! CHECK:           %[[VAL_9:.*]] = fir.call %[[VAL_8]](%[[VAL_5]]#0) fastmath<contract> : (!fir.ref<f32>) -> f32
-! CHECK:           hlfir.end_associate %[[VAL_5]]#1, %[[VAL_5]]#2 : !fir.ref<f32>, i1
 ! CHECK:           hlfir.assign %[[VAL_9]] to %[[VAL_2]]#0 : f32, !fir.ref<f32>
+! CHECK:           hlfir.end_associate %[[VAL_5]]#1, %[[VAL_5]]#2 : !fir.ref<f32>, i1
 
 subroutine test2(x)
   use proc_comp_defs, only : t, iface

>From d843f4eb45ad474adef13371540b2d22817074f0 Mon Sep 17 00:00:00 2001
From: Slava Zakharin <szakharin at nvidia.com>
Date: Thu, 8 May 2025 12:18:07 -0700
Subject: [PATCH 3/5] Fixed clean-ups insertion for atomic capture.

---
 flang/lib/Lower/OpenMP/OpenMP.cpp          |  4 +++-
 flang/test/Lower/OpenMP/atomic-capture.f90 | 20 ++++++++++++++++++++
 flang/test/Lower/OpenMP/atomic-update.f90  | 21 +++++++++++++++++++++
 3 files changed, 44 insertions(+), 1 deletion(-)

diff --git a/flang/lib/Lower/OpenMP/OpenMP.cpp b/flang/lib/Lower/OpenMP/OpenMP.cpp
index fcd3de9671098..d1a77a2624628 100644
--- a/flang/lib/Lower/OpenMP/OpenMP.cpp
+++ b/flang/lib/Lower/OpenMP/OpenMP.cpp
@@ -3129,7 +3129,9 @@ static void genAtomicCapture(lower::AbstractConverter &converter,
   }
   firOpBuilder.setInsertionPointToEnd(&block);
   firOpBuilder.create<mlir::omp::TerminatorOp>(loc);
-  firOpBuilder.setInsertionPointToStart(&block);
+  // The clean-ups associated with the statements inside the capture
+  // construct must be generated after the AtomicCaptureOp.
+  firOpBuilder.setInsertionPointAfter(atomicCaptureOp);
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/flang/test/Lower/OpenMP/atomic-capture.f90 b/flang/test/Lower/OpenMP/atomic-capture.f90
index bbb08220af9d9..b5c8edc8f31c1 100644
--- a/flang/test/Lower/OpenMP/atomic-capture.f90
+++ b/flang/test/Lower/OpenMP/atomic-capture.f90
@@ -97,3 +97,23 @@ subroutine pointers_in_atomic_capture()
         b = a
     !$omp end atomic
 end subroutine
+
+! Check that the clean-ups associated with the function call
+! are generated after the omp.atomic.capture operation:
+! CHECK-LABEL:   func.func @_QPfunc_call_cleanup(
+subroutine func_call_cleanup(x, v, vv)
+  integer :: x, v, vv
+
+! CHECK:           %[[VAL_7:.*]]:3 = hlfir.associate %{{.*}} {adapt.valuebyref} : (i32) -> (!fir.ref<i32>, !fir.ref<i32>, i1)
+! CHECK:           %[[VAL_8:.*]] = fir.call @_QPfunc(%[[VAL_7]]#0) fastmath<contract> : (!fir.ref<i32>) -> f32
+! CHECK:           %[[VAL_9:.*]] = fir.convert %[[VAL_8]] : (f32) -> i32
+! CHECK:           omp.atomic.capture {
+! CHECK:             omp.atomic.read %{{.*}} = %[[VAL_3:.*]]#0 : !fir.ref<i32>, !fir.ref<i32>, i32
+! CHECK:             omp.atomic.write %[[VAL_3]]#0 = %[[VAL_9]] : !fir.ref<i32>, i32
+! CHECK:           }
+! CHECK:           hlfir.end_associate %[[VAL_7]]#1, %[[VAL_7]]#2 : !fir.ref<i32>, i1
+  !$omp atomic capture
+  v = x
+  x = func(vv + 1)
+  !$omp end atomic
+end subroutine func_call_cleanup
diff --git a/flang/test/Lower/OpenMP/atomic-update.f90 b/flang/test/Lower/OpenMP/atomic-update.f90
index 31bf447006930..e0269ea1f8af1 100644
--- a/flang/test/Lower/OpenMP/atomic-update.f90
+++ b/flang/test/Lower/OpenMP/atomic-update.f90
@@ -201,3 +201,24 @@ program OmpAtomicUpdate
   !$omp atomic update
     x = x + sum([ (y+2, y=1, z) ])
 end program OmpAtomicUpdate
+
+! Check that the clean-ups associated with the function call
+! are generated after the omp.atomic.update operation:
+! CHECK-LABEL:   func.func @_QPfunc_call_cleanup(
+subroutine func_call_cleanup(v, vv)
+  integer v, vv
+
+! CHECK:           %[[VAL_6:.*]]:3 = hlfir.associate %{{.*}} {adapt.valuebyref} : (i32) -> (!fir.ref<i32>, !fir.ref<i32>, i1)
+! CHECK:           %[[VAL_7:.*]] = fir.call @_QPfunc(%[[VAL_6]]#0) fastmath<contract> : (!fir.ref<i32>) -> f32
+! CHECK:           omp.atomic.update %{{.*}} : !fir.ref<i32> {
+! CHECK:           ^bb0(%[[VAL_8:.*]]: i32):
+! CHECK:             %[[VAL_9:.*]] = fir.convert %[[VAL_8]] : (i32) -> f32
+! CHECK:             %[[VAL_10:.*]] = arith.addf %[[VAL_9]], %[[VAL_7]] fastmath<contract> : f32
+! CHECK:             %[[VAL_11:.*]] = fir.convert %[[VAL_10]] : (f32) -> i32
+! CHECK:             omp.yield(%[[VAL_11]] : i32)
+! CHECK:           }
+! CHECK:           hlfir.end_associate %[[VAL_6]]#1, %[[VAL_6]]#2 : !fir.ref<i32>, i1
+  !$omp atomic update
+  v = v + func(vv + 1)
+  !$omp end atomic
+end subroutine func_call_cleanup

>From 22c381e753d137e02ef81996c2ec65ca91f86410 Mon Sep 17 00:00:00 2001
From: Slava Zakharin <szakharin at nvidia.com>
Date: Thu, 8 May 2025 16:13:01 -0700
Subject: [PATCH 4/5] Fixed atomic capture cases with atomic update inside.

---
 flang/lib/Lower/OpenMP/OpenMP.cpp          | 20 +++++++---
 flang/test/Lower/OpenMP/atomic-capture.f90 | 44 ++++++++++++++++++++--
 2 files changed, 55 insertions(+), 9 deletions(-)

diff --git a/flang/lib/Lower/OpenMP/OpenMP.cpp b/flang/lib/Lower/OpenMP/OpenMP.cpp
index d1a77a2624628..5b0b54b9e0377 100644
--- a/flang/lib/Lower/OpenMP/OpenMP.cpp
+++ b/flang/lib/Lower/OpenMP/OpenMP.cpp
@@ -2729,7 +2729,8 @@ static void genAtomicUpdateStatement(
     const parser::Expr &assignmentStmtExpr,
     const parser::OmpAtomicClauseList *leftHandClauseList,
     const parser::OmpAtomicClauseList *rightHandClauseList, mlir::Location loc,
-    mlir::Operation *atomicCaptureOp = nullptr) {
+    mlir::Operation *atomicCaptureOp = nullptr,
+    lower::StatementContext *atomicCaptureStmtCtx = nullptr) {
   // Generate `atomic.update` operation for atomic assignment statements
   fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
   mlir::Location currentLocation = converter.getCurrentLocation();
@@ -2803,15 +2804,24 @@ static void genAtomicUpdateStatement(
       },
       assignmentStmtExpr.u);
   lower::StatementContext nonAtomicStmtCtx;
+  lower::StatementContext *stmtCtxPtr = &nonAtomicStmtCtx;
   if (!nonAtomicSubExprs.empty()) {
     // Generate non atomic part before all the atomic operations.
     auto insertionPoint = firOpBuilder.saveInsertionPoint();
-    if (atomicCaptureOp)
+    if (atomicCaptureOp) {
+      assert(atomicCaptureStmtCtx && "must specify statement context");
       firOpBuilder.setInsertionPoint(atomicCaptureOp);
+      // Any clean-ups associated with the expression lowering
+      // must also be generated outside of the atomic update operation
+      // and after the atomic capture operation.
+      // The atomicCaptureStmtCtx will be finalized at the end
+      // of the atomic capture operation generation.
+      stmtCtxPtr = atomicCaptureStmtCtx;
+    }
     mlir::Value nonAtomicVal;
     for (auto *nonAtomicSubExpr : nonAtomicSubExprs) {
       nonAtomicVal = fir::getBase(converter.genExprValue(
-          currentLocation, *nonAtomicSubExpr, nonAtomicStmtCtx));
+          currentLocation, *nonAtomicSubExpr, *stmtCtxPtr));
       exprValueOverrides.try_emplace(nonAtomicSubExpr, nonAtomicVal);
     }
     if (atomicCaptureOp)
@@ -3097,7 +3107,7 @@ static void genAtomicCapture(lower::AbstractConverter &converter,
       genAtomicUpdateStatement(
           converter, stmt2LHSArg, stmt2VarType, stmt2Var, stmt2Expr,
           /*leftHandClauseList=*/nullptr,
-          /*rightHandClauseList=*/nullptr, loc, atomicCaptureOp);
+          /*rightHandClauseList=*/nullptr, loc, atomicCaptureOp, &stmtCtx);
     } else {
       // Atomic capture construct is of the form [capture-stmt, write-stmt]
       firOpBuilder.setInsertionPoint(atomicCaptureOp);
@@ -3121,7 +3131,7 @@ static void genAtomicCapture(lower::AbstractConverter &converter,
     genAtomicUpdateStatement(
         converter, stmt1LHSArg, stmt1VarType, stmt1Var, stmt1Expr,
         /*leftHandClauseList=*/nullptr,
-        /*rightHandClauseList=*/nullptr, loc, atomicCaptureOp);
+        /*rightHandClauseList=*/nullptr, loc, atomicCaptureOp, &stmtCtx);
     genAtomicCaptureStatement(converter, stmt1LHSArg, stmt2LHSArg,
                               /*leftHandClauseList=*/nullptr,
                               /*rightHandClauseList=*/nullptr, elementType,
diff --git a/flang/test/Lower/OpenMP/atomic-capture.f90 b/flang/test/Lower/OpenMP/atomic-capture.f90
index b5c8edc8f31c1..2f800d534dc36 100644
--- a/flang/test/Lower/OpenMP/atomic-capture.f90
+++ b/flang/test/Lower/OpenMP/atomic-capture.f90
@@ -102,18 +102,54 @@ subroutine pointers_in_atomic_capture()
 ! are generated after the omp.atomic.capture operation:
 ! CHECK-LABEL:   func.func @_QPfunc_call_cleanup(
 subroutine func_call_cleanup(x, v, vv)
+  interface
+     integer function func(x)
+       integer :: x
+     end function func
+  end interface
   integer :: x, v, vv
 
 ! CHECK:           %[[VAL_7:.*]]:3 = hlfir.associate %{{.*}} {adapt.valuebyref} : (i32) -> (!fir.ref<i32>, !fir.ref<i32>, i1)
-! CHECK:           %[[VAL_8:.*]] = fir.call @_QPfunc(%[[VAL_7]]#0) fastmath<contract> : (!fir.ref<i32>) -> f32
-! CHECK:           %[[VAL_9:.*]] = fir.convert %[[VAL_8]] : (f32) -> i32
+! CHECK:           %[[VAL_8:.*]] = fir.call @_QPfunc(%[[VAL_7]]#0) fastmath<contract> : (!fir.ref<i32>) -> i32
 ! CHECK:           omp.atomic.capture {
-! CHECK:             omp.atomic.read %{{.*}} = %[[VAL_3:.*]]#0 : !fir.ref<i32>, !fir.ref<i32>, i32
-! CHECK:             omp.atomic.write %[[VAL_3]]#0 = %[[VAL_9]] : !fir.ref<i32>, i32
+! CHECK:             omp.atomic.read %[[VAL_1:.*]]#0 = %[[VAL_3:.*]]#0 : !fir.ref<i32>, !fir.ref<i32>, i32
+! CHECK:             omp.atomic.write %[[VAL_3]]#0 = %[[VAL_8]] : !fir.ref<i32>, i32
 ! CHECK:           }
 ! CHECK:           hlfir.end_associate %[[VAL_7]]#1, %[[VAL_7]]#2 : !fir.ref<i32>, i1
   !$omp atomic capture
   v = x
   x = func(vv + 1)
   !$omp end atomic
+
+! CHECK:           %[[VAL_12:.*]]:3 = hlfir.associate %{{.*}} {adapt.valuebyref} : (i32) -> (!fir.ref<i32>, !fir.ref<i32>, i1)
+! CHECK:           %[[VAL_13:.*]] = fir.call @_QPfunc(%[[VAL_12]]#0) fastmath<contract> : (!fir.ref<i32>) -> i32
+! CHECK:           omp.atomic.capture {
+! CHECK:             omp.atomic.read %[[VAL_1]]#0 = %[[VAL_3]]#0 : !fir.ref<i32>, !fir.ref<i32>, i32
+! CHECK:             omp.atomic.update %[[VAL_3]]#0 : !fir.ref<i32> {
+! CHECK:             ^bb0(%[[VAL_14:.*]]: i32):
+! CHECK:               %[[VAL_15:.*]] = arith.addi %[[VAL_13]], %[[VAL_14]] : i32
+! CHECK:               omp.yield(%[[VAL_15]] : i32)
+! CHECK:             }
+! CHECK:           }
+! CHECK:           hlfir.end_associate %[[VAL_12]]#1, %[[VAL_12]]#2 : !fir.ref<i32>, i1
+  !$omp atomic capture
+  v = x
+  x = func(vv + 1) + x
+  !$omp end atomic
+
+! CHECK:           %[[VAL_19:.*]]:3 = hlfir.associate %{{.*}} {adapt.valuebyref} : (i32) -> (!fir.ref<i32>, !fir.ref<i32>, i1)
+! CHECK:           %[[VAL_20:.*]] = fir.call @_QPfunc(%[[VAL_19]]#0) fastmath<contract> : (!fir.ref<i32>) -> i32
+! CHECK:           omp.atomic.capture {
+! CHECK:             omp.atomic.update %[[VAL_3]]#0 : !fir.ref<i32> {
+! CHECK:             ^bb0(%[[VAL_21:.*]]: i32):
+! CHECK:               %[[VAL_22:.*]] = arith.addi %[[VAL_20]], %[[VAL_21]] : i32
+! CHECK:               omp.yield(%[[VAL_22]] : i32)
+! CHECK:             }
+! CHECK:             omp.atomic.read %[[VAL_1]]#0 = %[[VAL_3]]#0 : !fir.ref<i32>, !fir.ref<i32>, i32
+! CHECK:           }
+! CHECK:           hlfir.end_associate %[[VAL_19]]#1, %[[VAL_19]]#2 : !fir.ref<i32>, i1
+  !$omp atomic capture
+  x = func(vv + 1) + x
+  v = x
+  !$omp end atomic
 end subroutine func_call_cleanup

>From 94df928e0031c9d26cad29841701959b7322508b Mon Sep 17 00:00:00 2001
From: Slava Zakharin <szakharin at nvidia.com>
Date: Fri, 9 May 2025 19:33:34 -0700
Subject: [PATCH 5/5] Fixed atomic handling for OpenACC.

---
 flang/lib/Lower/OpenACC.cpp                   | 24 ++++++--
 .../test/Lower/OpenACC/acc-atomic-capture.f90 | 57 +++++++++++++++++++
 .../test/Lower/OpenACC/acc-atomic-update.f90  | 18 +++++-
 3 files changed, 92 insertions(+), 7 deletions(-)

diff --git a/flang/lib/Lower/OpenACC.cpp b/flang/lib/Lower/OpenACC.cpp
index 82daa05c165cb..c4529a3115996 100644
--- a/flang/lib/Lower/OpenACC.cpp
+++ b/flang/lib/Lower/OpenACC.cpp
@@ -414,7 +414,8 @@ static inline void genAtomicUpdateStatement(
     Fortran::lower::AbstractConverter &converter, mlir::Value lhsAddr,
     mlir::Type varType, const Fortran::parser::Variable &assignmentStmtVariable,
     const Fortran::parser::Expr &assignmentStmtExpr, mlir::Location loc,
-    mlir::Operation *atomicCaptureOp = nullptr) {
+    mlir::Operation *atomicCaptureOp = nullptr,
+    Fortran::lower::StatementContext *atomicCaptureStmtCtx = nullptr) {
   // Generate `atomic.update` operation for atomic assignment statements
   fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
   mlir::Location currentLocation = converter.getCurrentLocation();
@@ -494,15 +495,24 @@ static inline void genAtomicUpdateStatement(
       },
       assignmentStmtExpr.u);
   Fortran::lower::StatementContext nonAtomicStmtCtx;
+  Fortran::lower::StatementContext *stmtCtxPtr = &nonAtomicStmtCtx;
   if (!nonAtomicSubExprs.empty()) {
     // Generate non atomic part before all the atomic operations.
     auto insertionPoint = firOpBuilder.saveInsertionPoint();
-    if (atomicCaptureOp)
+    if (atomicCaptureOp) {
+      assert(atomicCaptureStmtCtx && "must specify statement context");
       firOpBuilder.setInsertionPoint(atomicCaptureOp);
+      // Any clean-ups associated with the expression lowering
+      // must also be generated outside of the atomic update operation
+      // and after the atomic capture operation.
+      // The atomicCaptureStmtCtx will be finalized at the end
+      // of the atomic capture operation generation.
+      stmtCtxPtr = atomicCaptureStmtCtx;
+    }
     mlir::Value nonAtomicVal;
     for (auto *nonAtomicSubExpr : nonAtomicSubExprs) {
       nonAtomicVal = fir::getBase(converter.genExprValue(
-          currentLocation, *nonAtomicSubExpr, nonAtomicStmtCtx));
+          currentLocation, *nonAtomicSubExpr, *stmtCtxPtr));
       exprValueOverrides.try_emplace(nonAtomicSubExpr, nonAtomicVal);
     }
     if (atomicCaptureOp)
@@ -650,7 +660,7 @@ void genAtomicCapture(Fortran::lower::AbstractConverter &converter,
       genAtomicCaptureStatement(converter, stmt2LHSArg, stmt1LHSArg,
                                 elementType, loc);
       genAtomicUpdateStatement(converter, stmt2LHSArg, stmt2VarType, stmt2Var,
-                               stmt2Expr, loc, atomicCaptureOp);
+                               stmt2Expr, loc, atomicCaptureOp, &stmtCtx);
     } else {
       // Atomic capture construct is of the form [capture-stmt, write-stmt]
       firOpBuilder.setInsertionPoint(atomicCaptureOp);
@@ -670,13 +680,15 @@ void genAtomicCapture(Fortran::lower::AbstractConverter &converter,
         *Fortran::semantics::GetExpr(stmt2Expr);
     mlir::Type elementType = converter.genType(fromExpr);
     genAtomicUpdateStatement(converter, stmt1LHSArg, stmt1VarType, stmt1Var,
-                             stmt1Expr, loc, atomicCaptureOp);
+                             stmt1Expr, loc, atomicCaptureOp, &stmtCtx);
     genAtomicCaptureStatement(converter, stmt1LHSArg, stmt2LHSArg, elementType,
                               loc);
   }
   firOpBuilder.setInsertionPointToEnd(&block);
   firOpBuilder.create<mlir::acc::TerminatorOp>(loc);
-  firOpBuilder.setInsertionPointToStart(&block);
+  // The clean-ups associated with the statements inside the capture
+  // construct must be generated after the AtomicCaptureOp.
+  firOpBuilder.setInsertionPointAfter(atomicCaptureOp);
 }
 
 template <typename Op>
diff --git a/flang/test/Lower/OpenACC/acc-atomic-capture.f90 b/flang/test/Lower/OpenACC/acc-atomic-capture.f90
index 82059908bcd0b..ee38ab6ce826a 100644
--- a/flang/test/Lower/OpenACC/acc-atomic-capture.f90
+++ b/flang/test/Lower/OpenACC/acc-atomic-capture.f90
@@ -306,3 +306,60 @@ end subroutine comp_ref_in_atomic_capture2
 ! CHECK:             }
 ! CHECK:             acc.atomic.read %[[V_DECL]]#0 = %[[C]] : !fir.ref<i32>, !fir.ref<i32>, i32
 ! CHECK:           }
+
+! CHECK-LABEL:   func.func @_QPatomic_capture_with_associate() {
+subroutine atomic_capture_with_associate
+  interface
+     integer function func(x)
+       integer :: x
+     end function func
+  end interface
+! CHECK:           %[[X_DECL:.*]]:2 = hlfir.declare %{{.*}} {uniq_name = "_QFatomic_capture_with_associateEx"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
+! CHECK:           %[[Y_DECL:.*]]:2 = hlfir.declare %{{.*}} {uniq_name = "_QFatomic_capture_with_associateEy"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
+! CHECK:           %[[Z_DECL:.*]]:2 = hlfir.declare %{{.*}} {uniq_name = "_QFatomic_capture_with_associateEz"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
+  integer :: x, y, z
+
+! CHECK:           %[[VAL_10:.*]]:3 = hlfir.associate %{{.*}} {adapt.valuebyref} : (i32) -> (!fir.ref<i32>, !fir.ref<i32>, i1)
+! CHECK:           %[[VAL_11:.*]] = fir.call @_QPfunc(%[[VAL_10]]#0) fastmath<contract> : (!fir.ref<i32>) -> i32
+! CHECK:           acc.atomic.capture {
+! CHECK:             acc.atomic.read %[[X_DECL]]#0 = %[[Y_DECL]]#0 : !fir.ref<i32>, !fir.ref<i32>, i32
+! CHECK:             acc.atomic.write %[[Y_DECL]]#0 = %[[VAL_11]] : !fir.ref<i32>, i32
+! CHECK:           }
+! CHECK:           hlfir.end_associate %[[VAL_10]]#1, %[[VAL_10]]#2 : !fir.ref<i32>, i1
+  !$acc atomic capture
+  x = y
+  y = func(z + 1)
+  !$acc end atomic
+
+! CHECK:           %[[VAL_15:.*]]:3 = hlfir.associate %{{.*}} {adapt.valuebyref} : (i32) -> (!fir.ref<i32>, !fir.ref<i32>, i1)
+! CHECK:           %[[VAL_16:.*]] = fir.call @_QPfunc(%[[VAL_15]]#0) fastmath<contract> : (!fir.ref<i32>) -> i32
+! CHECK:           acc.atomic.capture {
+! CHECK:             acc.atomic.update %[[Y_DECL]]#0 : !fir.ref<i32> {
+! CHECK:             ^bb0(%[[VAL_17:.*]]: i32):
+! CHECK:               %[[VAL_18:.*]] = arith.muli %[[VAL_16]], %[[VAL_17]] : i32
+! CHECK:               acc.yield %[[VAL_18]] : i32
+! CHECK:             }
+! CHECK:             acc.atomic.read %[[X_DECL]]#0 = %[[Y_DECL]]#0 : !fir.ref<i32>, !fir.ref<i32>, i32
+! CHECK:           }
+! CHECK:           hlfir.end_associate %[[VAL_15]]#1, %[[VAL_15]]#2 : !fir.ref<i32>, i1
+  !$acc atomic capture
+  y = func(z + 1) * y
+  x = y
+  !$acc end atomic
+
+! CHECK:           %[[VAL_22:.*]]:3 = hlfir.associate %{{.*}} {adapt.valuebyref} : (i32) -> (!fir.ref<i32>, !fir.ref<i32>, i1)
+! CHECK:           %[[VAL_23:.*]] = fir.call @_QPfunc(%[[VAL_22]]#0) fastmath<contract> : (!fir.ref<i32>) -> i32
+! CHECK:           acc.atomic.capture {
+! CHECK:             acc.atomic.read %[[X_DECL]]#0 = %[[Y_DECL]]#0 : !fir.ref<i32>, !fir.ref<i32>, i32
+! CHECK:             acc.atomic.update %[[Y_DECL]]#0 : !fir.ref<i32> {
+! CHECK:             ^bb0(%[[VAL_24:.*]]: i32):
+! CHECK:               %[[VAL_25:.*]] = arith.addi %[[VAL_23]], %[[VAL_24]] : i32
+! CHECK:               acc.yield %[[VAL_25]] : i32
+! CHECK:             }
+! CHECK:           }
+! CHECK:           hlfir.end_associate %[[VAL_22]]#1, %[[VAL_22]]#2 : !fir.ref<i32>, i1
+  !$acc atomic capture
+  x = y
+  y = func(z + 1) + y
+  !$acc end atomic
+end subroutine atomic_capture_with_associate
diff --git a/flang/test/Lower/OpenACC/acc-atomic-update.f90 b/flang/test/Lower/OpenACC/acc-atomic-update.f90
index da2972877244c..71aa69fd64eba 100644
--- a/flang/test/Lower/OpenACC/acc-atomic-update.f90
+++ b/flang/test/Lower/OpenACC/acc-atomic-update.f90
@@ -3,6 +3,11 @@
 ! RUN: %flang_fc1 -fopenacc -emit-hlfir %s -o - | FileCheck %s
 
 program acc_atomic_update_test
+    interface
+       integer function func(x)
+         integer :: x
+       end function func
+    end interface
     integer :: x, y, z
     integer, pointer :: a, b
     integer, target :: c, d
@@ -67,7 +72,18 @@ program acc_atomic_update_test
     !$acc atomic
       i1 = i1 + 1
     !$acc end atomic
+
+!CHECK:  %[[VAL_44:.*]]:3 = hlfir.associate %{{.*}} {adapt.valuebyref} : (i32) -> (!fir.ref<i32>, !fir.ref<i32>, i1)
+!CHECK:  %[[VAL_45:.*]] = fir.call @_QPfunc(%[[VAL_44]]#0) fastmath<contract> : (!fir.ref<i32>) -> i32
+!CHECK:  acc.atomic.update %[[X_DECL]]#0 : !fir.ref<i32> {
+!CHECK:  ^bb0(%[[VAL_46:.*]]: i32):
+!CHECK:    %[[VAL_47:.*]] = arith.addi %[[VAL_46]], %[[VAL_45]] : i32
+!CHECK:    acc.yield %[[VAL_47]] : i32
+!CHECK:  }
+!CHECK:  hlfir.end_associate %[[VAL_44]]#1, %[[VAL_44]]#2 : !fir.ref<i32>, i1
+    !$acc atomic update
+    x = x + func(z + 1)
+    !$acc end atomic
 !CHECK:  return
 !CHECK: }
 end program acc_atomic_update_test
-



More information about the flang-commits mailing list