[flang-commits] [flang] 09b772e - [flang] Postpone hlfir.end_associate generation for calls. (#138786)

via flang-commits flang-commits at lists.llvm.org
Mon May 12 14:03:19 PDT 2025


Author: Slava Zakharin
Date: 2025-05-12T14:03:15-07:00
New Revision: 09b772e2efad804fdda02e2bd9ee44a2aaaddeeb

URL: https://github.com/llvm/llvm-project/commit/09b772e2efad804fdda02e2bd9ee44a2aaaddeeb
DIFF: https://github.com/llvm/llvm-project/commit/09b772e2efad804fdda02e2bd9ee44a2aaaddeeb.diff

LOG: [flang] Postpone hlfir.end_associate generation for calls. (#138786)

If we generate hlfir.end_associate at the end of the statement,
we get more easily optimizable HLFIR, because there are no
compiler-generated operations with side effects in between the call
and the consumers. This allows more hlfir.eval_in_mem operations to
reuse the LHS instead of allocating a temporary buffer.

I do not think the same can always be done for hlfir.copy_out, e.g.:
```
subroutine test2(x)
  interface
     function array_func2(x,y)
       real:: x(*), array_func2(10), y
     end function array_func2
  end interface
  real :: x(:)
  x = array_func2(x, 1.0)
end subroutine test2
```

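If we postpone the copy-out until after the assignment, then
the result may be wrong: the copy-out would overwrite the newly
assigned value of `x` with the contents of the copy-in temporary.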

Added: 
    flang/test/Lower/HLFIR/call-postponed-associate.f90

Modified: 
    flang/lib/Lower/ConvertCall.cpp
    flang/lib/Lower/OpenACC.cpp
    flang/lib/Lower/OpenMP/OpenMP.cpp
    flang/test/Lower/HLFIR/entry_return.f90
    flang/test/Lower/HLFIR/proc-pointer-comp-nopass.f90
    flang/test/Lower/OpenACC/acc-atomic-capture.f90
    flang/test/Lower/OpenACC/acc-atomic-update.f90
    flang/test/Lower/OpenMP/atomic-capture.f90
    flang/test/Lower/OpenMP/atomic-update.f90

Removed: 
    


################################################################################
diff  --git a/flang/lib/Lower/ConvertCall.cpp b/flang/lib/Lower/ConvertCall.cpp
index a5b85e25b1af0..d37d51f6ec634 100644
--- a/flang/lib/Lower/ConvertCall.cpp
+++ b/flang/lib/Lower/ConvertCall.cpp
@@ -960,9 +960,26 @@ struct CallCleanUp {
     mlir::Value tempVar;
     mlir::Value mustFree;
   };
-  void genCleanUp(mlir::Location loc, fir::FirOpBuilder &builder) {
-    Fortran::common::visit([&](auto &c) { c.genCleanUp(loc, builder); },
+
+  /// Generate clean-up code.
+  /// If \p postponeAssociates is true, the ExprAssociate clean-up
+  /// is not generated, and instead the corresponding CallCleanUp
+  /// object is returned as the result.
+  std::optional<CallCleanUp> genCleanUp(mlir::Location loc,
+                                        fir::FirOpBuilder &builder,
+                                        bool postponeAssociates) {
+    std::optional<CallCleanUp> postponed;
+    Fortran::common::visit(Fortran::common::visitors{
+                               [&](CopyIn &c) { c.genCleanUp(loc, builder); },
+                               [&](ExprAssociate &c) {
+                                 if (postponeAssociates)
+                                   postponed = CallCleanUp{c};
+                                 else
+                                   c.genCleanUp(loc, builder);
+                               },
+                           },
                            cleanUp);
+    return postponed;
   }
   std::variant<CopyIn, ExprAssociate> cleanUp;
 };
@@ -1729,10 +1746,23 @@ genUserCall(Fortran::lower::PreparedActualArguments &loweredActuals,
       caller, callSiteType, callContext.resultType,
       callContext.isElementalProcWithArrayArgs());
 
-  /// Clean-up associations and copy-in.
-  for (auto cleanUp : callCleanUps)
-    cleanUp.genCleanUp(loc, builder);
-
+  // Clean-up associations and copy-in.
+  // The association clean-ups are postponed to the end of the statement
+  // lowering. The copy-in clean-ups may be delayed as well,
+  // but they are done immediately after the call currently.
+  llvm::SmallVector<CallCleanUp> associateCleanups;
+  for (auto cleanUp : callCleanUps) {
+    auto postponed =
+        cleanUp.genCleanUp(loc, builder, /*postponeAssociates=*/true);
+    if (postponed)
+      associateCleanups.push_back(*postponed);
+  }
+
+  fir::FirOpBuilder *bldr = &builder;
+  callContext.stmtCtx.attachCleanup([=]() {
+    for (auto cleanUp : associateCleanups)
+      (void)cleanUp.genCleanUp(loc, *bldr, /*postponeAssociates=*/false);
+  });
   if (auto *entity = std::get_if<hlfir::EntityWithAttributes>(&loweredResult))
     return *entity;
 

diff  --git a/flang/lib/Lower/OpenACC.cpp b/flang/lib/Lower/OpenACC.cpp
index 2f70041a04dde..e1918288d6de3 100644
--- a/flang/lib/Lower/OpenACC.cpp
+++ b/flang/lib/Lower/OpenACC.cpp
@@ -416,7 +416,8 @@ static inline void genAtomicUpdateStatement(
     Fortran::lower::AbstractConverter &converter, mlir::Value lhsAddr,
     mlir::Type varType, const Fortran::parser::Variable &assignmentStmtVariable,
     const Fortran::parser::Expr &assignmentStmtExpr, mlir::Location loc,
-    mlir::Operation *atomicCaptureOp = nullptr) {
+    mlir::Operation *atomicCaptureOp = nullptr,
+    Fortran::lower::StatementContext *atomicCaptureStmtCtx = nullptr) {
   // Generate `atomic.update` operation for atomic assignment statements
   fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
   mlir::Location currentLocation = converter.getCurrentLocation();
@@ -496,15 +497,24 @@ static inline void genAtomicUpdateStatement(
       },
       assignmentStmtExpr.u);
   Fortran::lower::StatementContext nonAtomicStmtCtx;
+  Fortran::lower::StatementContext *stmtCtxPtr = &nonAtomicStmtCtx;
   if (!nonAtomicSubExprs.empty()) {
     // Generate non atomic part before all the atomic operations.
     auto insertionPoint = firOpBuilder.saveInsertionPoint();
-    if (atomicCaptureOp)
+    if (atomicCaptureOp) {
+      assert(atomicCaptureStmtCtx && "must specify statement context");
       firOpBuilder.setInsertionPoint(atomicCaptureOp);
+      // Any clean-ups associated with the expression lowering
+      // must also be generated outside of the atomic update operation
+      // and after the atomic capture operation.
+      // The atomicCaptureStmtCtx will be finalized at the end
+      // of the atomic capture operation generation.
+      stmtCtxPtr = atomicCaptureStmtCtx;
+    }
     mlir::Value nonAtomicVal;
     for (auto *nonAtomicSubExpr : nonAtomicSubExprs) {
       nonAtomicVal = fir::getBase(converter.genExprValue(
-          currentLocation, *nonAtomicSubExpr, nonAtomicStmtCtx));
+          currentLocation, *nonAtomicSubExpr, *stmtCtxPtr));
       exprValueOverrides.try_emplace(nonAtomicSubExpr, nonAtomicVal);
     }
     if (atomicCaptureOp)
@@ -652,7 +662,7 @@ void genAtomicCapture(Fortran::lower::AbstractConverter &converter,
       genAtomicCaptureStatement(converter, stmt2LHSArg, stmt1LHSArg,
                                 elementType, loc);
       genAtomicUpdateStatement(converter, stmt2LHSArg, stmt2VarType, stmt2Var,
-                               stmt2Expr, loc, atomicCaptureOp);
+                               stmt2Expr, loc, atomicCaptureOp, &stmtCtx);
     } else {
       // Atomic capture construct is of the form [capture-stmt, write-stmt]
       firOpBuilder.setInsertionPoint(atomicCaptureOp);
@@ -672,13 +682,15 @@ void genAtomicCapture(Fortran::lower::AbstractConverter &converter,
         *Fortran::semantics::GetExpr(stmt2Expr);
     mlir::Type elementType = converter.genType(fromExpr);
     genAtomicUpdateStatement(converter, stmt1LHSArg, stmt1VarType, stmt1Var,
-                             stmt1Expr, loc, atomicCaptureOp);
+                             stmt1Expr, loc, atomicCaptureOp, &stmtCtx);
     genAtomicCaptureStatement(converter, stmt1LHSArg, stmt2LHSArg, elementType,
                               loc);
   }
   firOpBuilder.setInsertionPointToEnd(&block);
   firOpBuilder.create<mlir::acc::TerminatorOp>(loc);
-  firOpBuilder.setInsertionPointToStart(&block);
+  // The clean-ups associated with the statements inside the capture
+  // construct must be generated after the AtomicCaptureOp.
+  firOpBuilder.setInsertionPointAfter(atomicCaptureOp);
 }
 
 template <typename Op>

diff  --git a/flang/lib/Lower/OpenMP/OpenMP.cpp b/flang/lib/Lower/OpenMP/OpenMP.cpp
index 446aa2deb3d05..4909c3e277a07 100644
--- a/flang/lib/Lower/OpenMP/OpenMP.cpp
+++ b/flang/lib/Lower/OpenMP/OpenMP.cpp
@@ -2816,7 +2816,8 @@ static void genAtomicUpdateStatement(
     const parser::Expr &assignmentStmtExpr,
     const parser::OmpAtomicClauseList *leftHandClauseList,
     const parser::OmpAtomicClauseList *rightHandClauseList, mlir::Location loc,
-    mlir::Operation *atomicCaptureOp = nullptr) {
+    mlir::Operation *atomicCaptureOp = nullptr,
+    lower::StatementContext *atomicCaptureStmtCtx = nullptr) {
   // Generate `atomic.update` operation for atomic assignment statements
   fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
   mlir::Location currentLocation = converter.getCurrentLocation();
@@ -2890,15 +2891,24 @@ static void genAtomicUpdateStatement(
       },
       assignmentStmtExpr.u);
   lower::StatementContext nonAtomicStmtCtx;
+  lower::StatementContext *stmtCtxPtr = &nonAtomicStmtCtx;
   if (!nonAtomicSubExprs.empty()) {
     // Generate non atomic part before all the atomic operations.
     auto insertionPoint = firOpBuilder.saveInsertionPoint();
-    if (atomicCaptureOp)
+    if (atomicCaptureOp) {
+      assert(atomicCaptureStmtCtx && "must specify statement context");
       firOpBuilder.setInsertionPoint(atomicCaptureOp);
+      // Any clean-ups associated with the expression lowering
+      // must also be generated outside of the atomic update operation
+      // and after the atomic capture operation.
+      // The atomicCaptureStmtCtx will be finalized at the end
+      // of the atomic capture operation generation.
+      stmtCtxPtr = atomicCaptureStmtCtx;
+    }
     mlir::Value nonAtomicVal;
     for (auto *nonAtomicSubExpr : nonAtomicSubExprs) {
       nonAtomicVal = fir::getBase(converter.genExprValue(
-          currentLocation, *nonAtomicSubExpr, nonAtomicStmtCtx));
+          currentLocation, *nonAtomicSubExpr, *stmtCtxPtr));
       exprValueOverrides.try_emplace(nonAtomicSubExpr, nonAtomicVal);
     }
     if (atomicCaptureOp)
@@ -3238,7 +3248,7 @@ static void genAtomicCapture(lower::AbstractConverter &converter,
       genAtomicUpdateStatement(
           converter, stmt2LHSArg, stmt2VarType, stmt2Var, stmt2Expr,
           /*leftHandClauseList=*/nullptr,
-          /*rightHandClauseList=*/nullptr, loc, atomicCaptureOp);
+          /*rightHandClauseList=*/nullptr, loc, atomicCaptureOp, &stmtCtx);
     } else {
       // Atomic capture construct is of the form [capture-stmt, write-stmt]
       firOpBuilder.setInsertionPoint(atomicCaptureOp);
@@ -3284,7 +3294,7 @@ static void genAtomicCapture(lower::AbstractConverter &converter,
     genAtomicUpdateStatement(
         converter, stmt1LHSArg, stmt1VarType, stmt1Var, stmt1Expr,
         /*leftHandClauseList=*/nullptr,
-        /*rightHandClauseList=*/nullptr, loc, atomicCaptureOp);
+        /*rightHandClauseList=*/nullptr, loc, atomicCaptureOp, &stmtCtx);
 
     if (stmt1VarType != stmt2VarType) {
       mlir::Value alloca;
@@ -3316,7 +3326,9 @@ static void genAtomicCapture(lower::AbstractConverter &converter,
   }
   firOpBuilder.setInsertionPointToEnd(&block);
   firOpBuilder.create<mlir::omp::TerminatorOp>(loc);
-  firOpBuilder.setInsertionPointToStart(&block);
+  // The clean-ups associated with the statements inside the capture
+  // construct must be generated after the AtomicCaptureOp.
+  firOpBuilder.setInsertionPointAfter(atomicCaptureOp);
 }
 
 //===----------------------------------------------------------------------===//

diff  --git a/flang/test/Lower/HLFIR/call-postponed-associate.f90 b/flang/test/Lower/HLFIR/call-postponed-associate.f90
new file mode 100644
index 0000000000000..18df62b44324b
--- /dev/null
+++ b/flang/test/Lower/HLFIR/call-postponed-associate.f90
@@ -0,0 +1,85 @@
+! RUN: bbc -emit-hlfir -o - %s -I nowhere | FileCheck %s
+
+subroutine test1
+  interface
+     function array_func1(x)
+       real:: x, array_func1(10)
+     end function array_func1
+  end interface
+  real :: x(10)
+  x = array_func1(1.0)
+end subroutine test1
+! CHECK-LABEL:   func.func @_QPtest1() {
+! CHECK:           %[[VAL_5:.*]] = arith.constant 1.000000e+00 : f32
+! CHECK:           %[[VAL_6:.*]]:3 = hlfir.associate %[[VAL_5]] {adapt.valuebyref} : (f32) -> (!fir.ref<f32>, !fir.ref<f32>, i1)
+! CHECK:           %[[VAL_17:.*]] = hlfir.eval_in_mem shape %{{.*}} : (!fir.shape<1>) -> !hlfir.expr<10xf32> {
+! CHECK:             fir.call @_QParray_func1
+! CHECK:             fir.save_result
+! CHECK:           }
+! CHECK:           hlfir.assign %[[VAL_17]] to %{{.*}} : !hlfir.expr<10xf32>, !fir.ref<!fir.array<10xf32>>
+! CHECK:           hlfir.end_associate %[[VAL_6]]#1, %[[VAL_6]]#2 : !fir.ref<f32>, i1
+
+subroutine test2(x)
+  interface
+     function array_func2(x,y)
+       real:: x(*), array_func2(10), y
+     end function array_func2
+  end interface
+  real :: x(:)
+  x = array_func2(x, 1.0)
+end subroutine test2
+! CHECK-LABEL:   func.func @_QPtest2(
+! CHECK:           %[[VAL_3:.*]] = arith.constant 1.000000e+00 : f32
+! CHECK:           %[[VAL_4:.*]]:2 = hlfir.copy_in %{{.*}} to %{{.*}} : (!fir.box<!fir.array<?xf32>>, !fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>>) -> (!fir.box<!fir.array<?xf32>>, i1)
+! CHECK:           %[[VAL_5:.*]] = fir.box_addr %[[VAL_4]]#0 : (!fir.box<!fir.array<?xf32>>) -> !fir.ref<!fir.array<?xf32>>
+! CHECK:           %[[VAL_6:.*]]:3 = hlfir.associate %[[VAL_3]] {adapt.valuebyref} : (f32) -> (!fir.ref<f32>, !fir.ref<f32>, i1)
+! CHECK:           %[[VAL_17:.*]] = hlfir.eval_in_mem shape %{{.*}} : (!fir.shape<1>) -> !hlfir.expr<10xf32> {
+! CHECK:           ^bb0(%[[VAL_18:.*]]: !fir.ref<!fir.array<10xf32>>):
+! CHECK:             %[[VAL_19:.*]] = fir.call @_QParray_func2(%[[VAL_5]], %[[VAL_6]]#0) fastmath<contract> : (!fir.ref<!fir.array<?xf32>>, !fir.ref<f32>) -> !fir.array<10xf32>
+! CHECK:             fir.save_result %[[VAL_19]] to %[[VAL_18]](%{{.*}}) : !fir.array<10xf32>, !fir.ref<!fir.array<10xf32>>, !fir.shape<1>
+! CHECK:           }
+! CHECK:           hlfir.copy_out %{{.*}}, %[[VAL_4]]#1 to %{{.*}} : (!fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>>, i1, !fir.box<!fir.array<?xf32>>) -> ()
+! CHECK:           hlfir.assign %[[VAL_17]] to %{{.*}} : !hlfir.expr<10xf32>, !fir.box<!fir.array<?xf32>>
+! CHECK:           hlfir.end_associate %[[VAL_6]]#1, %[[VAL_6]]#2 : !fir.ref<f32>, i1
+! CHECK:           hlfir.destroy %[[VAL_17]] : !hlfir.expr<10xf32>
+
+subroutine test3(x)
+  interface
+     function array_func3(x)
+       real :: x, array_func3(10)
+     end function array_func3
+  end interface
+  logical :: x
+  if (any(array_func3(1.0).le.array_func3(2.0))) x = .true.
+end subroutine test3
+! CHECK-LABEL:   func.func @_QPtest3(
+! CHECK:           %[[VAL_2:.*]] = arith.constant 1.000000e+00 : f32
+! CHECK:           %[[VAL_3:.*]]:3 = hlfir.associate %[[VAL_2]] {adapt.valuebyref} : (f32) -> (!fir.ref<f32>, !fir.ref<f32>, i1)
+! CHECK:           %[[VAL_14:.*]] = hlfir.eval_in_mem shape %{{.*}} : (!fir.shape<1>) -> !hlfir.expr<10xf32> {
+! CHECK:           ^bb0(%[[VAL_15:.*]]: !fir.ref<!fir.array<10xf32>>):
+! CHECK:             %[[VAL_16:.*]] = fir.call @_QParray_func3(%[[VAL_3]]#0) fastmath<contract> : (!fir.ref<f32>) -> !fir.array<10xf32>
+! CHECK:             fir.save_result %[[VAL_16]] to %[[VAL_15]](%{{.*}}) : !fir.array<10xf32>, !fir.ref<!fir.array<10xf32>>, !fir.shape<1>
+! CHECK:           }
+! CHECK:           %[[VAL_17:.*]] = arith.constant 2.000000e+00 : f32
+! CHECK:           %[[VAL_18:.*]]:3 = hlfir.associate %[[VAL_17]] {adapt.valuebyref} : (f32) -> (!fir.ref<f32>, !fir.ref<f32>, i1)
+! CHECK:           %[[VAL_29:.*]] = hlfir.eval_in_mem shape %{{.*}} : (!fir.shape<1>) -> !hlfir.expr<10xf32> {
+! CHECK:           ^bb0(%[[VAL_30:.*]]: !fir.ref<!fir.array<10xf32>>):
+! CHECK:             %[[VAL_31:.*]] = fir.call @_QParray_func3(%[[VAL_18]]#0) fastmath<contract> : (!fir.ref<f32>) -> !fir.array<10xf32>
+! CHECK:             fir.save_result %[[VAL_31]] to %[[VAL_30]](%{{.*}}) : !fir.array<10xf32>, !fir.ref<!fir.array<10xf32>>, !fir.shape<1>
+! CHECK:           }
+! CHECK:           %[[VAL_32:.*]] = hlfir.elemental %{{.*}} unordered : (!fir.shape<1>) -> !hlfir.expr<?x!fir.logical<4>> {
+! CHECK:           ^bb0(%[[VAL_33:.*]]: index):
+! CHECK:             %[[VAL_34:.*]] = hlfir.apply %[[VAL_14]], %[[VAL_33]] : (!hlfir.expr<10xf32>, index) -> f32
+! CHECK:             %[[VAL_35:.*]] = hlfir.apply %[[VAL_29]], %[[VAL_33]] : (!hlfir.expr<10xf32>, index) -> f32
+! CHECK:             %[[VAL_36:.*]] = arith.cmpf ole, %[[VAL_34]], %[[VAL_35]] fastmath<contract> : f32
+! CHECK:             %[[VAL_37:.*]] = fir.convert %[[VAL_36]] : (i1) -> !fir.logical<4>
+! CHECK:             hlfir.yield_element %[[VAL_37]] : !fir.logical<4>
+! CHECK:           }
+! CHECK:           %[[VAL_38:.*]] = hlfir.any %[[VAL_32]] : (!hlfir.expr<?x!fir.logical<4>>) -> !fir.logical<4>
+! CHECK:           hlfir.destroy %[[VAL_32]] : !hlfir.expr<?x!fir.logical<4>>
+! CHECK:           hlfir.end_associate %[[VAL_18]]#1, %[[VAL_18]]#2 : !fir.ref<f32>, i1
+! CHECK:           hlfir.destroy %[[VAL_29]] : !hlfir.expr<10xf32>
+! CHECK:           hlfir.end_associate %[[VAL_3]]#1, %[[VAL_3]]#2 : !fir.ref<f32>, i1
+! CHECK:           hlfir.destroy %[[VAL_14]] : !hlfir.expr<10xf32>
+! CHECK:           %[[VAL_39:.*]] = fir.convert %[[VAL_38]] : (!fir.logical<4>) -> i1
+! CHECK:           fir.if %[[VAL_39]] {

diff  --git a/flang/test/Lower/HLFIR/entry_return.f90 b/flang/test/Lower/HLFIR/entry_return.f90
index 5d3e160af2df6..18fb2b571b950 100644
--- a/flang/test/Lower/HLFIR/entry_return.f90
+++ b/flang/test/Lower/HLFIR/entry_return.f90
@@ -51,13 +51,13 @@ logical function f2()
 ! CHECK:           %[[VAL_6:.*]]:3 = hlfir.associate %[[VAL_4]] {adapt.valuebyref} : (f32) -> (!fir.ref<f32>, !fir.ref<f32>, i1)
 ! CHECK:           %[[VAL_7:.*]]:3 = hlfir.associate %[[VAL_5]] {adapt.valuebyref} : (f32) -> (!fir.ref<f32>, !fir.ref<f32>, i1)
 ! CHECK:           %[[VAL_8:.*]] = fir.call @_QPcomplex(%[[VAL_6]]#0, %[[VAL_7]]#0) fastmath<contract> : (!fir.ref<f32>, !fir.ref<f32>) -> f32
-! CHECK:           hlfir.end_associate %[[VAL_6]]#1, %[[VAL_6]]#2 : !fir.ref<f32>, i1
-! CHECK:           hlfir.end_associate %[[VAL_7]]#1, %[[VAL_7]]#2 : !fir.ref<f32>, i1
 ! CHECK:           %[[VAL_9:.*]] = arith.constant 0.000000e+00 : f32
 ! CHECK:           %[[VAL_10:.*]] = fir.undefined complex<f32>
 ! CHECK:           %[[VAL_11:.*]] = fir.insert_value %[[VAL_10]], %[[VAL_8]], [0 : index] : (complex<f32>, f32) -> complex<f32>
 ! CHECK:           %[[VAL_12:.*]] = fir.insert_value %[[VAL_11]], %[[VAL_9]], [1 : index] : (complex<f32>, f32) -> complex<f32>
 ! CHECK:           hlfir.assign %[[VAL_12]] to %[[VAL_1]]#0 : complex<f32>, !fir.ref<complex<f32>>
+! CHECK:           hlfir.end_associate %[[VAL_6]]#1, %[[VAL_6]]#2 : !fir.ref<f32>, i1
+! CHECK:           hlfir.end_associate %[[VAL_7]]#1, %[[VAL_7]]#2 : !fir.ref<f32>, i1
 ! CHECK:           %[[VAL_13:.*]] = fir.load %[[VAL_3]]#0 : !fir.ref<!fir.logical<4>>
 ! CHECK:           return %[[VAL_13]] : !fir.logical<4>
 ! CHECK:         }
@@ -74,13 +74,13 @@ logical function f2()
 ! CHECK:           %[[VAL_6:.*]]:3 = hlfir.associate %[[VAL_4]] {adapt.valuebyref} : (f32) -> (!fir.ref<f32>, !fir.ref<f32>, i1)
 ! CHECK:           %[[VAL_7:.*]]:3 = hlfir.associate %[[VAL_5]] {adapt.valuebyref} : (f32) -> (!fir.ref<f32>, !fir.ref<f32>, i1)
 ! CHECK:           %[[VAL_8:.*]] = fir.call @_QPcomplex(%[[VAL_6]]#0, %[[VAL_7]]#0) fastmath<contract> : (!fir.ref<f32>, !fir.ref<f32>) -> f32
-! CHECK:           hlfir.end_associate %[[VAL_6]]#1, %[[VAL_6]]#2 : !fir.ref<f32>, i1
-! CHECK:           hlfir.end_associate %[[VAL_7]]#1, %[[VAL_7]]#2 : !fir.ref<f32>, i1
 ! CHECK:           %[[VAL_9:.*]] = arith.constant 0.000000e+00 : f32
 ! CHECK:           %[[VAL_10:.*]] = fir.undefined complex<f32>
 ! CHECK:           %[[VAL_11:.*]] = fir.insert_value %[[VAL_10]], %[[VAL_8]], [0 : index] : (complex<f32>, f32) -> complex<f32>
 ! CHECK:           %[[VAL_12:.*]] = fir.insert_value %[[VAL_11]], %[[VAL_9]], [1 : index] : (complex<f32>, f32) -> complex<f32>
 ! CHECK:           hlfir.assign %[[VAL_12]] to %[[VAL_1]]#0 : complex<f32>, !fir.ref<complex<f32>>
+! CHECK:           hlfir.end_associate %[[VAL_6]]#1, %[[VAL_6]]#2 : !fir.ref<f32>, i1
+! CHECK:           hlfir.end_associate %[[VAL_7]]#1, %[[VAL_7]]#2 : !fir.ref<f32>, i1
 ! CHECK:           %[[VAL_13:.*]] = fir.load %[[VAL_1]]#0 : !fir.ref<complex<f32>>
 ! CHECK:           return %[[VAL_13]] : complex<f32>
 ! CHECK:         }

diff  --git a/flang/test/Lower/HLFIR/proc-pointer-comp-nopass.f90 b/flang/test/Lower/HLFIR/proc-pointer-comp-nopass.f90
index 28659a33d0893..206b6e4e9b797 100644
--- a/flang/test/Lower/HLFIR/proc-pointer-comp-nopass.f90
+++ b/flang/test/Lower/HLFIR/proc-pointer-comp-nopass.f90
@@ -32,8 +32,8 @@ real function test1(x)
 ! CHECK:           %[[VAL_7:.*]] = fir.load %[[VAL_6]] : !fir.ref<!fir.boxproc<(!fir.ref<f32>) -> f32>>
 ! CHECK:           %[[VAL_8:.*]] = fir.box_addr %[[VAL_7]] : (!fir.boxproc<(!fir.ref<f32>) -> f32>) -> ((!fir.ref<f32>) -> f32)
 ! CHECK:           %[[VAL_9:.*]] = fir.call %[[VAL_8]](%[[VAL_5]]#0) fastmath<contract> : (!fir.ref<f32>) -> f32
-! CHECK:           hlfir.end_associate %[[VAL_5]]#1, %[[VAL_5]]#2 : !fir.ref<f32>, i1
 ! CHECK:           hlfir.assign %[[VAL_9]] to %[[VAL_2]]#0 : f32, !fir.ref<f32>
+! CHECK:           hlfir.end_associate %[[VAL_5]]#1, %[[VAL_5]]#2 : !fir.ref<f32>, i1
 
 subroutine test2(x)
   use proc_comp_defs, only : t, iface

diff  --git a/flang/test/Lower/OpenACC/acc-atomic-capture.f90 b/flang/test/Lower/OpenACC/acc-atomic-capture.f90
index 82059908bcd0b..ee38ab6ce826a 100644
--- a/flang/test/Lower/OpenACC/acc-atomic-capture.f90
+++ b/flang/test/Lower/OpenACC/acc-atomic-capture.f90
@@ -306,3 +306,60 @@ end subroutine comp_ref_in_atomic_capture2
 ! CHECK:             }
 ! CHECK:             acc.atomic.read %[[V_DECL]]#0 = %[[C]] : !fir.ref<i32>, !fir.ref<i32>, i32
 ! CHECK:           }
+
+! CHECK-LABEL:   func.func @_QPatomic_capture_with_associate() {
+subroutine atomic_capture_with_associate
+  interface
+     integer function func(x)
+       integer :: x
+     end function func
+  end interface
+! CHECK:           %[[X_DECL:.*]]:2 = hlfir.declare %{{.*}} {uniq_name = "_QFatomic_capture_with_associateEx"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
+! CHECK:           %[[Y_DECL:.*]]:2 = hlfir.declare %{{.*}} {uniq_name = "_QFatomic_capture_with_associateEy"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
+! CHECK:           %[[Z_DECL:.*]]:2 = hlfir.declare %{{.*}} {uniq_name = "_QFatomic_capture_with_associateEz"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
+  integer :: x, y, z
+
+! CHECK:           %[[VAL_10:.*]]:3 = hlfir.associate %{{.*}} {adapt.valuebyref} : (i32) -> (!fir.ref<i32>, !fir.ref<i32>, i1)
+! CHECK:           %[[VAL_11:.*]] = fir.call @_QPfunc(%[[VAL_10]]#0) fastmath<contract> : (!fir.ref<i32>) -> i32
+! CHECK:           acc.atomic.capture {
+! CHECK:             acc.atomic.read %[[X_DECL]]#0 = %[[Y_DECL]]#0 : !fir.ref<i32>, !fir.ref<i32>, i32
+! CHECK:             acc.atomic.write %[[Y_DECL]]#0 = %[[VAL_11]] : !fir.ref<i32>, i32
+! CHECK:           }
+! CHECK:           hlfir.end_associate %[[VAL_10]]#1, %[[VAL_10]]#2 : !fir.ref<i32>, i1
+  !$acc atomic capture
+  x = y
+  y = func(z + 1)
+  !$acc end atomic
+
+! CHECK:           %[[VAL_15:.*]]:3 = hlfir.associate %{{.*}} {adapt.valuebyref} : (i32) -> (!fir.ref<i32>, !fir.ref<i32>, i1)
+! CHECK:           %[[VAL_16:.*]] = fir.call @_QPfunc(%[[VAL_15]]#0) fastmath<contract> : (!fir.ref<i32>) -> i32
+! CHECK:           acc.atomic.capture {
+! CHECK:             acc.atomic.update %[[Y_DECL]]#0 : !fir.ref<i32> {
+! CHECK:             ^bb0(%[[VAL_17:.*]]: i32):
+! CHECK:               %[[VAL_18:.*]] = arith.muli %[[VAL_16]], %[[VAL_17]] : i32
+! CHECK:               acc.yield %[[VAL_18]] : i32
+! CHECK:             }
+! CHECK:             acc.atomic.read %[[X_DECL]]#0 = %[[Y_DECL]]#0 : !fir.ref<i32>, !fir.ref<i32>, i32
+! CHECK:           }
+! CHECK:           hlfir.end_associate %[[VAL_15]]#1, %[[VAL_15]]#2 : !fir.ref<i32>, i1
+  !$acc atomic capture
+  y = func(z + 1) * y
+  x = y
+  !$acc end atomic
+
+! CHECK:           %[[VAL_22:.*]]:3 = hlfir.associate %{{.*}} {adapt.valuebyref} : (i32) -> (!fir.ref<i32>, !fir.ref<i32>, i1)
+! CHECK:           %[[VAL_23:.*]] = fir.call @_QPfunc(%[[VAL_22]]#0) fastmath<contract> : (!fir.ref<i32>) -> i32
+! CHECK:           acc.atomic.capture {
+! CHECK:             acc.atomic.read %[[X_DECL]]#0 = %[[Y_DECL]]#0 : !fir.ref<i32>, !fir.ref<i32>, i32
+! CHECK:             acc.atomic.update %[[Y_DECL]]#0 : !fir.ref<i32> {
+! CHECK:             ^bb0(%[[VAL_24:.*]]: i32):
+! CHECK:               %[[VAL_25:.*]] = arith.addi %[[VAL_23]], %[[VAL_24]] : i32
+! CHECK:               acc.yield %[[VAL_25]] : i32
+! CHECK:             }
+! CHECK:           }
+! CHECK:           hlfir.end_associate %[[VAL_22]]#1, %[[VAL_22]]#2 : !fir.ref<i32>, i1
+  !$acc atomic capture
+  x = y
+  y = func(z + 1) + y
+  !$acc end atomic
+end subroutine atomic_capture_with_associate

diff  --git a/flang/test/Lower/OpenACC/acc-atomic-update.f90 b/flang/test/Lower/OpenACC/acc-atomic-update.f90
index da2972877244c..71aa69fd64eba 100644
--- a/flang/test/Lower/OpenACC/acc-atomic-update.f90
+++ b/flang/test/Lower/OpenACC/acc-atomic-update.f90
@@ -3,6 +3,11 @@
 ! RUN: %flang_fc1 -fopenacc -emit-hlfir %s -o - | FileCheck %s
 
 program acc_atomic_update_test
+    interface
+       integer function func(x)
+         integer :: x
+       end function func
+    end interface
     integer :: x, y, z
     integer, pointer :: a, b
     integer, target :: c, d
@@ -67,7 +72,18 @@ program acc_atomic_update_test
     !$acc atomic
       i1 = i1 + 1
     !$acc end atomic
+
+!CHECK:  %[[VAL_44:.*]]:3 = hlfir.associate %{{.*}} {adapt.valuebyref} : (i32) -> (!fir.ref<i32>, !fir.ref<i32>, i1)
+!CHECK:  %[[VAL_45:.*]] = fir.call @_QPfunc(%[[VAL_44]]#0) fastmath<contract> : (!fir.ref<i32>) -> i32
+!CHECK:  acc.atomic.update %[[X_DECL]]#0 : !fir.ref<i32> {
+!CHECK:  ^bb0(%[[VAL_46:.*]]: i32):
+!CHECK:    %[[VAL_47:.*]] = arith.addi %[[VAL_46]], %[[VAL_45]] : i32
+!CHECK:    acc.yield %[[VAL_47]] : i32
+!CHECK:  }
+!CHECK:  hlfir.end_associate %[[VAL_44]]#1, %[[VAL_44]]#2 : !fir.ref<i32>, i1
+    !$acc atomic update
+    x = x + func(z + 1)
+    !$acc end atomic
 !CHECK:  return
 !CHECK: }
 end program acc_atomic_update_test
-

diff  --git a/flang/test/Lower/OpenMP/atomic-capture.f90 b/flang/test/Lower/OpenMP/atomic-capture.f90
index bbb08220af9d9..2f800d534dc36 100644
--- a/flang/test/Lower/OpenMP/atomic-capture.f90
+++ b/flang/test/Lower/OpenMP/atomic-capture.f90
@@ -97,3 +97,59 @@ subroutine pointers_in_atomic_capture()
         b = a
     !$omp end atomic
 end subroutine
+
+! Check that the clean-ups associated with the function call
+! are generated after the omp.atomic.capture operation:
+! CHECK-LABEL:   func.func @_QPfunc_call_cleanup(
+subroutine func_call_cleanup(x, v, vv)
+  interface
+     integer function func(x)
+       integer :: x
+     end function func
+  end interface
+  integer :: x, v, vv
+
+! CHECK:           %[[VAL_7:.*]]:3 = hlfir.associate %{{.*}} {adapt.valuebyref} : (i32) -> (!fir.ref<i32>, !fir.ref<i32>, i1)
+! CHECK:           %[[VAL_8:.*]] = fir.call @_QPfunc(%[[VAL_7]]#0) fastmath<contract> : (!fir.ref<i32>) -> i32
+! CHECK:           omp.atomic.capture {
+! CHECK:             omp.atomic.read %[[VAL_1:.*]]#0 = %[[VAL_3:.*]]#0 : !fir.ref<i32>, !fir.ref<i32>, i32
+! CHECK:             omp.atomic.write %[[VAL_3]]#0 = %[[VAL_8]] : !fir.ref<i32>, i32
+! CHECK:           }
+! CHECK:           hlfir.end_associate %[[VAL_7]]#1, %[[VAL_7]]#2 : !fir.ref<i32>, i1
+  !$omp atomic capture
+  v = x
+  x = func(vv + 1)
+  !$omp end atomic
+
+! CHECK:           %[[VAL_12:.*]]:3 = hlfir.associate %{{.*}} {adapt.valuebyref} : (i32) -> (!fir.ref<i32>, !fir.ref<i32>, i1)
+! CHECK:           %[[VAL_13:.*]] = fir.call @_QPfunc(%[[VAL_12]]#0) fastmath<contract> : (!fir.ref<i32>) -> i32
+! CHECK:           omp.atomic.capture {
+! CHECK:             omp.atomic.read %[[VAL_1]]#0 = %[[VAL_3]]#0 : !fir.ref<i32>, !fir.ref<i32>, i32
+! CHECK:             omp.atomic.update %[[VAL_3]]#0 : !fir.ref<i32> {
+! CHECK:             ^bb0(%[[VAL_14:.*]]: i32):
+! CHECK:               %[[VAL_15:.*]] = arith.addi %[[VAL_13]], %[[VAL_14]] : i32
+! CHECK:               omp.yield(%[[VAL_15]] : i32)
+! CHECK:             }
+! CHECK:           }
+! CHECK:           hlfir.end_associate %[[VAL_12]]#1, %[[VAL_12]]#2 : !fir.ref<i32>, i1
+  !$omp atomic capture
+  v = x
+  x = func(vv + 1) + x
+  !$omp end atomic
+
+! CHECK:           %[[VAL_19:.*]]:3 = hlfir.associate %{{.*}} {adapt.valuebyref} : (i32) -> (!fir.ref<i32>, !fir.ref<i32>, i1)
+! CHECK:           %[[VAL_20:.*]] = fir.call @_QPfunc(%[[VAL_19]]#0) fastmath<contract> : (!fir.ref<i32>) -> i32
+! CHECK:           omp.atomic.capture {
+! CHECK:             omp.atomic.update %[[VAL_3]]#0 : !fir.ref<i32> {
+! CHECK:             ^bb0(%[[VAL_21:.*]]: i32):
+! CHECK:               %[[VAL_22:.*]] = arith.addi %[[VAL_20]], %[[VAL_21]] : i32
+! CHECK:               omp.yield(%[[VAL_22]] : i32)
+! CHECK:             }
+! CHECK:             omp.atomic.read %[[VAL_1]]#0 = %[[VAL_3]]#0 : !fir.ref<i32>, !fir.ref<i32>, i32
+! CHECK:           }
+! CHECK:           hlfir.end_associate %[[VAL_19]]#1, %[[VAL_19]]#2 : !fir.ref<i32>, i1
+  !$omp atomic capture
+  x = func(vv + 1) + x
+  v = x
+  !$omp end atomic
+end subroutine func_call_cleanup

diff  --git a/flang/test/Lower/OpenMP/atomic-update.f90 b/flang/test/Lower/OpenMP/atomic-update.f90
index 257ae8fb497ff..3f840acefa6e8 100644
--- a/flang/test/Lower/OpenMP/atomic-update.f90
+++ b/flang/test/Lower/OpenMP/atomic-update.f90
@@ -219,3 +219,24 @@ program OmpAtomicUpdate
   !$omp atomic update
     w = w + g  
 end program OmpAtomicUpdate
+
+! Check that the clean-ups associated with the function call
+! are generated after the omp.atomic.update operation:
+! CHECK-LABEL:   func.func @_QPfunc_call_cleanup(
+subroutine func_call_cleanup(v, vv)
+  integer v, vv
+
+! CHECK:           %[[VAL_6:.*]]:3 = hlfir.associate %{{.*}} {adapt.valuebyref} : (i32) -> (!fir.ref<i32>, !fir.ref<i32>, i1)
+! CHECK:           %[[VAL_7:.*]] = fir.call @_QPfunc(%[[VAL_6]]#0) fastmath<contract> : (!fir.ref<i32>) -> f32
+! CHECK:           omp.atomic.update %{{.*}} : !fir.ref<i32> {
+! CHECK:           ^bb0(%[[VAL_8:.*]]: i32):
+! CHECK:             %[[VAL_9:.*]] = fir.convert %[[VAL_8]] : (i32) -> f32
+! CHECK:             %[[VAL_10:.*]] = arith.addf %[[VAL_9]], %[[VAL_7]] fastmath<contract> : f32
+! CHECK:             %[[VAL_11:.*]] = fir.convert %[[VAL_10]] : (f32) -> i32
+! CHECK:             omp.yield(%[[VAL_11]] : i32)
+! CHECK:           }
+! CHECK:           hlfir.end_associate %[[VAL_6]]#1, %[[VAL_6]]#2 : !fir.ref<i32>, i1
+  !$omp atomic update
+  v = v + func(vv + 1)
+  !$omp end atomic
+end subroutine func_call_cleanup


        


More information about the flang-commits mailing list