[Mlir-commits] [flang] [mlir] [flang][openacc][openmp] Support implicit casting on the atomic interface (PR #114390)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Oct 31 04:08:22 PDT 2024


https://github.com/khaki3 created https://github.com/llvm/llvm-project/pull/114390

OpenACC/OpenMP (ACCMP) atomics do not support implicit type conversion. In particular, atomic reads are lowered to semantically incorrect code.

Example:

```
program main
  implicit none
  real(8) :: n
  integer :: x
  x = 1.0
  !$acc atomic capture
  n = x
  x = n
  !$acc end atomic
end program main
```

Compiling it with flang-new produces this error: `error: loc("rep1.f90":6:9): expected three operations in atomic.capture region (one terminator, and two atomic ops)`

Beyond that error, the generated FIR code below shows three issues.

1. `fir.convert` intrudes into the capture region.
2. An incorrect temporary (`%2`) is being updated instead of `n`.
3. If we allow `n` in place of `%2`, the operand types of `atomic.read` do not match. Introducing a `!fir.ref<i32> -> !fir.ref<f64>` conversion on `x` is incorrect because the value of `x` must be converted, not its reference.

```
    %2 = "fir.alloca"() <{in_type = i32, operandSegmentSizes = array<i32: 0, 0>}> : () -> !fir.ref<i32>
    %3 = "fir.alloca"() <{bindc_name = "n", in_type = f64, operandSegmentSizes = array<i32: 0, 0>, uniq_name = "_QFEn"}> : () -> !fir.ref<f64>
    %4:2 = "hlfir.declare"(%3) <{operandSegmentSizes = array<i32: 1, 0, 0, 0>, uniq_name = "_QFEn"}> : (!fir.ref<f64>) -> (!fir.ref<f64>, !fir.ref<f64>)
    %5 = "fir.alloca"() <{bindc_name = "x", in_type = i32, operandSegmentSizes = array<i32: 0, 0>, uniq_name = "_QFEx"}> : () -> !fir.ref<i32>
    %6:2 = "hlfir.declare"(%5) <{operandSegmentSizes = array<i32: 1, 0, 0, 0>, uniq_name = "_QFEx"}> : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
    %7 = "arith.constant"() <{value = 1 : i32}> : () -> i32
    "hlfir.assign"(%7, %6#0) : (i32, !fir.ref<i32>) -> ()
    %8 = "fir.load"(%4#0) : (!fir.ref<f64>) -> f64
    %9 = "fir.convert"(%8) : (f64) -> i32
    "fir.store"(%9, %2) : (i32, !fir.ref<i32>) -> ()
    %10 = "fir.load"(%6#0) : (!fir.ref<i32>) -> i32
    %11 = "fir.convert"(%10) : (i32) -> f64
    "acc.atomic.capture"() ({
      "acc.atomic.read"(%2, %6#1) <{element_type = f64}> : (!fir.ref<i32>, !fir.ref<i32>) -> ()
      %12 = "fir.convert"(%11) : (f64) -> i32
      "acc.atomic.write"(%2, %12) : (!fir.ref<i32>, i32) -> ()
      "acc.terminator"() : () -> ()
    }) : () -> ()
```
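
The third issue, converting a reference instead of a value, can be illustrated outside of FIR. The following is a minimal C++ analogy, not flang code: `reinterpretBits` and `convertValue` are hypothetical names, and `float` is used instead of `f64` only so that both types have the same size.

```cpp
#include <cstdint>
#include <cstring>

// Analogy for a reference conversion such as !fir.ref<i32> -> !fir.ref<f32>:
// the stored bits are reinterpreted as another type, not converted.
// std::memcpy is used instead of a pointer cast to keep the C++ well-defined.
float reinterpretBits(std::int32_t x) {
  static_assert(sizeof(float) == sizeof(std::int32_t), "sizes must match");
  float result;
  std::memcpy(&result, &x, sizeof(result));
  return result;
}

// Analogy for a value conversion such as fir.convert : (i32) -> f32:
// the numeric value is preserved.
float convertValue(std::int32_t x) { return static_cast<float>(x); }
```

On IEEE-754 platforms, `convertValue(1)` is `1.0f`, while `reinterpretBits(1)` is a subnormal value near `1.4e-45`. Fortran assignment semantics require the former.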

This PR fixes these issues in `flang/lib/Lower/DirectivesCommon.h` with the following changes (from top to bottom):

1. Move `fir.convert` for `atomic.write` out of the capture region.
2. Remove the `!fir.ref<i32> -> !fir.ref<f64>` conversion found in `genOmpAccAtomicRead`.
3. Eliminate unnecessary `genExprAddr` calls on the RHS, which create an invalid temporary for `x = 1.0`.
4. When generating a capture operation, refer to the original LHS instead of the type-casted RHS.
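
The first change relies on a save, move, create, restore pattern for the builder's insertion point. Below is a minimal sketch of that pattern with a toy builder. `ToyBuilder` and `emitWrite` are hypothetical and are not the MLIR or FIR API: operations are plain strings, and the enclosing block and the capture region are assumed to be separate vectors.

```cpp
#include <cstddef>
#include <string>
#include <vector>

// Hypothetical stand-in for an IR builder. The insertion point is a
// (block, index) pair; this is not the MLIR API.
struct ToyBuilder {
  std::vector<std::string> *block = nullptr;
  std::size_t index = 0;

  void setInsertionPoint(std::vector<std::string> &b, std::size_t i) {
    block = &b;
    index = i;
  }

  // Insert an operation at the insertion point and advance past it.
  void create(const std::string &op) {
    block->insert(block->begin() + static_cast<std::ptrdiff_t>(index), op);
    ++index;
  }
};

// Emit the conversion in the enclosing block, right after the operation that
// defines the RHS, then emit the write where the builder originally pointed.
void emitWrite(ToyBuilder &builder, std::vector<std::string> &outer,
               std::size_t afterRhs) {
  ToyBuilder saved = builder;                 // save the insertion point
  builder.setInsertionPoint(outer, afterRhs); // move outside the region
  builder.create("convert");
  builder = saved;                            // restore the insertion point
  builder.create("atomic.write");
}
```

Starting with an outer block `{"load"}` and a region `{"atomic.read"}` with the builder positioned at the end of the region, `emitWrite(builder, outer, 1)` leaves the conversion in the outer block and only the atomic operations in the region.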

These changes require that the operands of `atomic.read` be allowed to have different types. Thus, this PR also removes the `AllTypesMatch` trait from both `acc.atomic.read` and `omp.atomic.read`.

The example code is converted as follows:
```
    %0 = fir.alloca f64 {bindc_name = "n", uniq_name = "_QFEn"}
    %1:2 = hlfir.declare %0 {uniq_name = "_QFEn"} : (!fir.ref<f64>) -> (!fir.ref<f64>, !fir.ref<f64>)
    %2 = fir.alloca i32 {bindc_name = "x", uniq_name = "_QFEx"}
    %3:2 = hlfir.declare %2 {uniq_name = "_QFEx"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
    %c1_i32 = arith.constant 1 : i32
    hlfir.assign %c1_i32 to %3#0 : i32, !fir.ref<i32>
    %4 = fir.load %1#0 : !fir.ref<f64>
    %5 = fir.convert %4 : (f64) -> i32
    acc.atomic.capture {
      acc.atomic.read %1#1 = %3#1 : !fir.ref<f64>, !fir.ref<i32>, i32
      acc.atomic.write %3#1 = %5 : !fir.ref<i32>, i32
    }
```
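
The intended semantics of the converted IR can be modeled in C++. This is a sketch only: it assumes an atomic exchange is an acceptable model of the [capture-stmt, write-stmt] form, and it is not the lowering that flang performs. `captureReadThenWrite` is a hypothetical name.

```cpp
#include <atomic>
#include <cstdint>

// Mirrors the converted IR: the new value for x is computed from n before the
// atomic region. The region then reads the old x and stores the precomputed
// value into x; the old x reaches n through a value conversion.
void captureReadThenWrite(std::atomic<std::int32_t> &x, double &n) {
  // Corresponds to fir.load of n followed by fir.convert (f64 -> i32).
  const std::int32_t newX = static_cast<std::int32_t>(n);
  // Corresponds to the read and the write inside the capture region.
  const std::int32_t oldX = x.exchange(newX);
  // The i32 value is converted; no reference is reinterpreted.
  n = static_cast<double>(oldX);
}
```

For instance, with `x` holding 1 and `n` set to 2.5 (a value chosen for illustration, since `n` is uninitialized in the example), the call leaves `x` at 2 and `n` at 1.0.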


From d5b5a64fdfeb3adf1caed5bf7cd382d9c407b9fa Mon Sep 17 00:00:00 2001
From: Kazuaki Matsumura <kmatsumura at nvidia.com>
Date: Thu, 31 Oct 2024 03:54:22 -0700
Subject: [PATCH] [flang][openacc][openmp] Support implicit casting on the
 atomic interface

---
 flang/lib/Lower/DirectivesCommon.h            |  55 ++++------
 .../Fir/convert-to-llvm-openmp-and-fir.fir    |   4 +-
 .../test/Lower/OpenACC/acc-atomic-capture.f90 | 103 ++++++++++++++++--
 flang/test/Lower/OpenACC/acc-atomic-read.f90  |  19 ++--
 .../Lower/OpenACC/acc-atomic-update-array.f90 |   4 +-
 flang/test/Lower/OpenMP/atomic-capture.f90    |   6 +-
 flang/test/Lower/OpenMP/atomic-read.f90       |  14 +--
 .../mlir/Dialect/OpenACC/OpenACCOps.td        |   5 +-
 mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td |   5 +-
 9 files changed, 145 insertions(+), 70 deletions(-)

diff --git a/flang/lib/Lower/DirectivesCommon.h b/flang/lib/Lower/DirectivesCommon.h
index 421a44b128c017..88514b16743278 100644
--- a/flang/lib/Lower/DirectivesCommon.h
+++ b/flang/lib/Lower/DirectivesCommon.h
@@ -179,7 +179,11 @@ static inline void genOmpAccAtomicWriteStatement(
   fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
 
   mlir::Type varType = fir::unwrapRefType(lhsAddr.getType());
+  // Create a conversion outside the capture block.
+  auto insertionPoint = firOpBuilder.saveInsertionPoint();
+  firOpBuilder.setInsertionPointAfter(rhsExpr.getDefiningOp());
   rhsExpr = firOpBuilder.createConvert(loc, varType, rhsExpr);
+  firOpBuilder.restoreInsertionPoint(insertionPoint);
 
   processOmpAtomicTODO<AtomicListT>(varType, loc);
 
@@ -410,10 +414,6 @@ void genOmpAccAtomicRead(Fortran::lower::AbstractConverter &converter,
       fir::getBase(converter.genExprAddr(fromExpr, stmtCtx));
   mlir::Value toAddress = fir::getBase(converter.genExprAddr(
       *Fortran::semantics::GetExpr(assignmentStmtVariable), stmtCtx));
-  fir::FirOpBuilder &builder = converter.getFirOpBuilder();
-  if (fromAddress.getType() != toAddress.getType())
-    fromAddress =
-        builder.create<fir::ConvertOp>(loc, toAddress.getType(), fromAddress);
   genOmpAccAtomicCaptureStatement(converter, fromAddress, toAddress,
                                   leftHandClauseList, rightHandClauseList,
                                   elementType, loc);
@@ -497,23 +497,12 @@ void genOmpAccAtomicCapture(Fortran::lower::AbstractConverter &converter,
   // a `atomic.read`, `atomic.write`, or `atomic.update` operation
   // inside `atomic.capture`
   Fortran::lower::StatementContext stmtCtx;
-  mlir::Value stmt1LHSArg, stmt1RHSArg, stmt2LHSArg, stmt2RHSArg;
-  mlir::Type elementType;
   // LHS evaluations are common to all combinations of `atomic.capture`
-  stmt1LHSArg = fir::getBase(converter.genExprAddr(assign1.lhs, stmtCtx));
-  stmt2LHSArg = fir::getBase(converter.genExprAddr(assign2.lhs, stmtCtx));
+  mlir::Value stmt1LHSArg =
+      fir::getBase(converter.genExprAddr(assign1.lhs, stmtCtx));
+  mlir::Value stmt2LHSArg =
+      fir::getBase(converter.genExprAddr(assign2.lhs, stmtCtx));
 
-  // Operation specific RHS evaluations
-  if (Fortran::semantics::checkForSingleVariableOnRHS(stmt1)) {
-    // Atomic capture construct is of the form [capture-stmt, update-stmt] or
-    // of the form [capture-stmt, write-stmt]
-    stmt1RHSArg = fir::getBase(converter.genExprAddr(assign1.rhs, stmtCtx));
-    stmt2RHSArg = fir::getBase(converter.genExprValue(assign2.rhs, stmtCtx));
-  } else {
-    // Atomic capture construct is of the form [update-stmt, capture-stmt]
-    stmt1RHSArg = fir::getBase(converter.genExprValue(assign1.rhs, stmtCtx));
-    stmt2RHSArg = fir::getBase(converter.genExprAddr(assign2.lhs, stmtCtx));
-  }
   // Type information used in generation of `atomic.update` operation
   mlir::Type stmt1VarType =
       fir::getBase(converter.genExprValue(assign1.lhs, stmtCtx)).getType();
@@ -545,44 +534,46 @@ void genOmpAccAtomicCapture(Fortran::lower::AbstractConverter &converter,
       // Atomic capture construct is of the form [capture-stmt, update-stmt]
       const Fortran::semantics::SomeExpr &fromExpr =
           *Fortran::semantics::GetExpr(stmt1Expr);
-      elementType = converter.genType(fromExpr);
+      mlir::Type elementType = converter.genType(fromExpr);
       genOmpAccAtomicCaptureStatement<AtomicListT>(
-          converter, stmt1RHSArg, stmt1LHSArg,
+          converter, stmt2LHSArg, stmt1LHSArg,
           /*leftHandClauseList=*/nullptr,
           /*rightHandClauseList=*/nullptr, elementType, loc);
       genOmpAccAtomicUpdateStatement<AtomicListT>(
-          converter, stmt1RHSArg, stmt2VarType, stmt2Var, stmt2Expr,
+          converter, stmt2LHSArg, stmt2VarType, stmt2Var, stmt2Expr,
           /*leftHandClauseList=*/nullptr,
           /*rightHandClauseList=*/nullptr, loc, atomicCaptureOp);
     } else {
       // Atomic capture construct is of the form [capture-stmt, write-stmt]
+      firOpBuilder.setInsertionPoint(atomicCaptureOp);
+      mlir::Value stmt2RHSArg =
+          fir::getBase(converter.genExprValue(assign2.rhs, stmtCtx));
+      firOpBuilder.setInsertionPointToStart(&block);
       const Fortran::semantics::SomeExpr &fromExpr =
           *Fortran::semantics::GetExpr(stmt1Expr);
-      elementType = converter.genType(fromExpr);
+      mlir::Type elementType = converter.genType(fromExpr);
       genOmpAccAtomicCaptureStatement<AtomicListT>(
-          converter, stmt1RHSArg, stmt1LHSArg,
+          converter, stmt2LHSArg, stmt1LHSArg,
           /*leftHandClauseList=*/nullptr,
           /*rightHandClauseList=*/nullptr, elementType, loc);
       genOmpAccAtomicWriteStatement<AtomicListT>(
-          converter, stmt1RHSArg, stmt2RHSArg,
+          converter, stmt2LHSArg, stmt2RHSArg,
           /*leftHandClauseList=*/nullptr,
           /*rightHandClauseList=*/nullptr, loc);
     }
   } else {
     // Atomic capture construct is of the form [update-stmt, capture-stmt]
-    firOpBuilder.setInsertionPointToEnd(&block);
     const Fortran::semantics::SomeExpr &fromExpr =
         *Fortran::semantics::GetExpr(stmt2Expr);
-    elementType = converter.genType(fromExpr);
-    genOmpAccAtomicCaptureStatement<AtomicListT>(
-        converter, stmt1LHSArg, stmt2LHSArg,
-        /*leftHandClauseList=*/nullptr,
-        /*rightHandClauseList=*/nullptr, elementType, loc);
-    firOpBuilder.setInsertionPointToStart(&block);
+    mlir::Type elementType = converter.genType(fromExpr);
     genOmpAccAtomicUpdateStatement<AtomicListT>(
         converter, stmt1LHSArg, stmt1VarType, stmt1Var, stmt1Expr,
         /*leftHandClauseList=*/nullptr,
         /*rightHandClauseList=*/nullptr, loc, atomicCaptureOp);
+    genOmpAccAtomicCaptureStatement<AtomicListT>(
+        converter, stmt1LHSArg, stmt2LHSArg,
+        /*leftHandClauseList=*/nullptr,
+        /*rightHandClauseList=*/nullptr, elementType, loc);
   }
   firOpBuilder.setInsertionPointToEnd(&block);
   if constexpr (std::is_same<AtomicListT,
diff --git a/flang/test/Fir/convert-to-llvm-openmp-and-fir.fir b/flang/test/Fir/convert-to-llvm-openmp-and-fir.fir
index 168526518865b4..184abe24fe967d 100644
--- a/flang/test/Fir/convert-to-llvm-openmp-and-fir.fir
+++ b/flang/test/Fir/convert-to-llvm-openmp-and-fir.fir
@@ -781,11 +781,11 @@ func.func @_QPsimple_reduction(%arg0: !fir.ref<!fir.array<100x!fir.logical<4>>>
 // -----
 
 // CHECK: llvm.func @_QPs
-// CHECK: omp.atomic.read %{{.*}} = %{{.*}}   : !llvm.ptr, !llvm.struct<(f32, f32)>
+// CHECK: omp.atomic.read %{{.*}} = %{{.*}}   : !llvm.ptr, !llvm.ptr, !llvm.struct<(f32, f32)>
 
 func.func @_QPs(%arg0: !fir.ref<complex<f32>> {fir.bindc_name = "x"}) {
   %0 = fir.alloca complex<f32> {bindc_name = "v", uniq_name = "_QFsEv"}
-  omp.atomic.read %0 = %arg0   : !fir.ref<complex<f32>>, complex<f32>
+  omp.atomic.read %0 = %arg0   : !fir.ref<complex<f32>>, !fir.ref<complex<f32>>, complex<f32>
   return
 }
 
diff --git a/flang/test/Lower/OpenACC/acc-atomic-capture.f90 b/flang/test/Lower/OpenACC/acc-atomic-capture.f90
index 373683386fda90..66b8e8c5843a81 100644
--- a/flang/test/Lower/OpenACC/acc-atomic-capture.f90
+++ b/flang/test/Lower/OpenACC/acc-atomic-capture.f90
@@ -11,7 +11,7 @@ program acc_atomic_capture_test
 !CHECK: %[[Y_DECL:.*]]:2 = hlfir.declare %2 {uniq_name = "_QFEy"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
 !CHECK: %[[temp:.*]] = fir.load %[[X_DECL]]#0 : !fir.ref<i32>
 !CHECK: acc.atomic.capture {
-!CHECK: acc.atomic.read %[[X_DECL]]#1 = %[[Y_DECL]]#1 : !fir.ref<i32>
+!CHECK: acc.atomic.read %[[X_DECL]]#1 = %[[Y_DECL]]#1 : !fir.ref<i32>, !fir.ref<i32>, i32
 !CHECK: acc.atomic.update %[[Y_DECL]]#1 : !fir.ref<i32> {
 !CHECK: ^bb0(%[[ARG:.*]]: i32):
 !CHECK: %[[result:.*]] = arith.addi %[[temp]], %[[ARG]] : i32
@@ -32,7 +32,7 @@ program acc_atomic_capture_test
 !CHECK: %[[result:.*]] = arith.muli %[[temp]], %[[ARG]] : i32
 !CHECK: acc.yield %[[result]] : i32
 !CHECK: }
-!CHECK: acc.atomic.read %[[X_DECL]]#1 = %[[Y_DECL]]#1 : !fir.ref<i32>
+!CHECK: acc.atomic.read %[[X_DECL]]#1 = %[[Y_DECL]]#1 : !fir.ref<i32>, !fir.ref<i32>, i32
 !CHECK: }
 
     !$acc atomic capture
@@ -47,7 +47,7 @@ program acc_atomic_capture_test
 !CHECK: %[[result_noreassoc:.*]] = hlfir.no_reassoc %[[result]] : i32
 !CHECK: %[[result:.*]] = arith.addi %[[constant_20]], %[[result_noreassoc]] : i32
 !CHECK: acc.atomic.capture {
-!CHECK: acc.atomic.read %[[X_DECL]]#1 = %[[Y_DECL]]#1 : !fir.ref<i32>
+!CHECK: acc.atomic.read %[[X_DECL]]#1 = %[[Y_DECL]]#1 : !fir.ref<i32>, !fir.ref<i32>, i32
 !CHECK: acc.atomic.write %[[Y_DECL]]#1 = %[[result]] : !fir.ref<i32>, i32
 !CHECK: }
 
@@ -82,7 +82,7 @@ subroutine pointers_in_atomic_capture()
 !CHECK: %[[result:.*]] = arith.addi %[[ARG]], %[[loaded_value]] : i32
 !CHECK: acc.yield %[[result]] : i32
 !CHECK: }
-!CHECK: acc.atomic.read %[[loaded_B_addr]] = %[[loaded_A_addr]] : !fir.ptr<i32>, i32
+!CHECK: acc.atomic.read %[[loaded_B_addr]] = %[[loaded_A_addr]] : !fir.ptr<i32>, !fir.ptr<i32>, i32
 !CHECK: }
     integer, pointer :: a, b
     integer, target :: c, d
@@ -118,10 +118,95 @@ subroutine capture_with_convert_f32_to_i32()
 ! CHECK: %[[MUL:.*]] = arith.mulf %{{.*}}, %[[CST]] fastmath<contract> : f32
 ! CHECK: %[[CONV:.*]] = fir.convert %[[MUL]] : (f32) -> i32
 ! CHECK: acc.atomic.capture {
-! CHECK:   acc.atomic.read %[[V_DECL]]#1 = %[[K_DECL]]#1 : !fir.ref<i32>, i32
+! CHECK:   acc.atomic.read %[[V_DECL]]#1 = %[[K_DECL]]#1 : !fir.ref<i32>, !fir.ref<i32>, i32
 ! CHECK:   acc.atomic.write %[[K_DECL]]#1 = %[[CONV]] : !fir.ref<i32>, i32
 ! CHECK: }
 
+subroutine capture_with_convert_i32_to_f64()
+  real(8) :: x
+  integer :: v
+  x = 1.0
+  v = 0
+  !$acc atomic capture
+  v = x
+  x = v
+  !$acc end atomic
+end subroutine capture_with_convert_i32_to_f64
+
+! CHECK-LABEL: func.func @_QPcapture_with_convert_i32_to_f64()
+! CHECK: %[[V:.*]] = fir.alloca i32 {bindc_name = "v", uniq_name = "_QFcapture_with_convert_i32_to_f64Ev"}
+! CHECK: %[[V_DECL:.*]]:2 = hlfir.declare %[[V]] {uniq_name = "_QFcapture_with_convert_i32_to_f64Ev"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
+! CHECK: %[[X:.*]] = fir.alloca f64 {bindc_name = "x", uniq_name = "_QFcapture_with_convert_i32_to_f64Ex"}
+! CHECK: %[[X_DECL:.*]]:2 = hlfir.declare %[[X]] {uniq_name = "_QFcapture_with_convert_i32_to_f64Ex"} : (!fir.ref<f64>) -> (!fir.ref<f64>, !fir.ref<f64>)
+! CHECK: %[[LOAD:.*]] = fir.load %[[V_DECL]]#0 : !fir.ref<i32>
+! CHECK: %[[CONV:.*]] = fir.convert %[[LOAD]] : (i32) -> f64
+! CHECK: acc.atomic.capture {
+! CHECK:   acc.atomic.read %[[V_DECL]]#1 = %[[X_DECL]]#1 : !fir.ref<i32>, !fir.ref<f64>, f64
+! CHECK:   acc.atomic.write %[[X_DECL]]#1 = %[[CONV]] : !fir.ref<f64>, f64
+! CHECK: }
+
+subroutine capture_with_convert_f64_to_i32()
+  integer :: x
+  real(8) :: v
+  x = 1
+  v = 0
+  !$acc atomic capture
+  x = v * v
+  v = x
+  !$acc end atomic
+end subroutine capture_with_convert_f64_to_i32
+
+! CHECK-LABEL: func.func @_QPcapture_with_convert_f64_to_i32()
+! CHECK: %[[V:.*]] = fir.alloca f64 {bindc_name = "v", uniq_name = "_QFcapture_with_convert_f64_to_i32Ev"}
+! CHECK: %[[V_DECL:.*]]:2 = hlfir.declare %[[V]] {uniq_name = "_QFcapture_with_convert_f64_to_i32Ev"} : (!fir.ref<f64>) -> (!fir.ref<f64>, !fir.ref<f64>)
+! CHECK: %[[X:.*]] = fir.alloca i32 {bindc_name = "x", uniq_name = "_QFcapture_with_convert_f64_to_i32Ex"}
+! CHECK: %[[X_DECL:.*]]:2 = hlfir.declare %[[X]] {uniq_name = "_QFcapture_with_convert_f64_to_i32Ex"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
+! CHECK: %c1_i32 = arith.constant 1 : i32
+! CHECK: hlfir.assign %c1_i32 to %[[X_DECL]]#0 : i32, !fir.ref<i32>
+! CHECK: %[[CST:.*]] = arith.constant 0.000000e+00 : f64
+! CHECK: hlfir.assign %[[CST]] to %[[V_DECL]]#0 : f64, !fir.ref<f64>
+! CHECK: %[[LOAD:.*]] = fir.load %[[V_DECL]]#0 : !fir.ref<f64>
+! CHECK: acc.atomic.capture {
+! CHECK:   acc.atomic.update %[[X_DECL]]#1 : !fir.ref<i32> {
+! CHECK:   ^bb0(%arg0: i32):
+! CHECK:     %[[MUL:.*]] = arith.mulf %[[LOAD]], %[[LOAD]] fastmath<contract> : f64
+! CHECK:     %[[CONV:.*]] = fir.convert %[[MUL]] : (f64) -> i32
+! CHECK:     acc.yield %[[CONV]] : i32
+! CHECK:   }
+! CHECK:   acc.atomic.read %[[V_DECL]]#1 = %[[X_DECL]]#1 : !fir.ref<f64>, !fir.ref<i32>, i32
+! CHECK: }
+
+subroutine capture_with_convert_i32_to_f32()
+  real(4) :: x
+  integer :: v
+  x = 1.0
+  v = 0
+  !$acc atomic capture
+  v = x
+  x = x + v
+  !$acc end atomic
+end subroutine capture_with_convert_i32_to_f32
+
+! CHECK-LABEL: func.func @_QPcapture_with_convert_i32_to_f32()
+! CHECK: %[[V:.*]] = fir.alloca i32 {bindc_name = "v", uniq_name = "_QFcapture_with_convert_i32_to_f32Ev"}
+! CHECK: %[[V_DECL:.*]]:2 = hlfir.declare %[[V]] {uniq_name = "_QFcapture_with_convert_i32_to_f32Ev"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
+! CHECK: %[[X:.*]] = fir.alloca f32 {bindc_name = "x", uniq_name = "_QFcapture_with_convert_i32_to_f32Ex"}
+! CHECK: %[[X_DECL:.*]]:2 = hlfir.declare %[[X]] {uniq_name = "_QFcapture_with_convert_i32_to_f32Ex"} : (!fir.ref<f32>) -> (!fir.ref<f32>, !fir.ref<f32>)
+! CHECK: %[[CST:.*]] = arith.constant 1.000000e+00 : f32
+! CHECK: hlfir.assign %[[CST]] to %[[X_DECL]]#0 : f32, !fir.ref<f32>
+! CHECK: %c0_i32 = arith.constant 0 : i32
+! CHECK: hlfir.assign %c0_i32 to %[[V_DECL]]#0 : i32, !fir.ref<i32>
+! CHECK: %[[LOAD:.*]] = fir.load %[[V_DECL]]#0 : !fir.ref<i32>
+! CHECK: acc.atomic.capture {
+! CHECK:   acc.atomic.read %[[V_DECL]]#1 = %[[X_DECL]]#1 : !fir.ref<i32>, !fir.ref<f32>, f32
+! CHECK:   acc.atomic.update %[[X_DECL]]#1 : !fir.ref<f32> {
+! CHECK:   ^bb0(%arg0: f32):
+! CHECK:     %[[CONV:.*]] = fir.convert %[[LOAD]] : (i32) -> f32
+! CHECK:     %[[ADD:.*]] = arith.addf %arg0, %[[CONV]] fastmath<contract> : f32
+! CHECK:     acc.yield %[[ADD]] : f32
+! CHECK:   }
+! CHECK: }
+
 subroutine array_ref_in_atomic_capture1
   integer :: x(10), v
   !$acc atomic capture
@@ -136,7 +221,7 @@ end subroutine array_ref_in_atomic_capture1
 ! CHECK:           %[[X_DECL:.*]]:2 = hlfir.declare %[[X]](%{{.*}}) {uniq_name = "_QFarray_ref_in_atomic_capture1Ex"} : (!fir.ref<!fir.array<10xi32>>, !fir.shape<1>) -> (!fir.ref<!fir.array<10xi32>>, !fir.ref<!fir.array<10xi32>>)
 ! CHECK:           %[[X_REF:.*]] = hlfir.designate %[[X_DECL]]#0 (%{{.*}})  : (!fir.ref<!fir.array<10xi32>>, index) -> !fir.ref<i32>
 ! CHECK:           acc.atomic.capture {
-! CHECK:             acc.atomic.read %[[V_DECL]]#1 = %[[X_REF]] : !fir.ref<i32>, i32
+! CHECK:             acc.atomic.read %[[V_DECL]]#1 = %[[X_REF]] : !fir.ref<i32>, !fir.ref<i32>, i32
 ! CHECK:             acc.atomic.update %[[X_REF]] : !fir.ref<i32> {
 ! CHECK:             ^bb0(%[[VAL_7:.*]]: i32):
 ! CHECK:               %[[VAL_8:.*]] = arith.addi %[[VAL_7]], %{{.*}} : i32
@@ -163,7 +248,7 @@ end subroutine array_ref_in_atomic_capture2
 ! CHECK:               %[[VAL_8:.*]] = arith.addi %[[VAL_7]], %{{.*}} : i32
 ! CHECK:               acc.yield %[[VAL_8]] : i32
 ! CHECK:             }
-! CHECK:             acc.atomic.read %[[V_DECL]]#1 = %[[X_REF]] : !fir.ref<i32>, i32
+! CHECK:             acc.atomic.read %[[V_DECL]]#1 = %[[X_REF]] : !fir.ref<i32>, !fir.ref<i32>, i32
 ! CHECK:           }
 
 subroutine comp_ref_in_atomic_capture1
@@ -184,7 +269,7 @@ end subroutine comp_ref_in_atomic_capture1
 ! CHECK:           %[[X_DECL:.*]]:2 = hlfir.declare %[[X]] {uniq_name = "_QFcomp_ref_in_atomic_capture1Ex"} : (!fir.ref<!fir.type<_QFcomp_ref_in_atomic_capture1Tt1{c:i32}>>) -> (!fir.ref<!fir.type<_QFcomp_ref_in_atomic_capture1Tt1{c:i32}>>, !fir.ref<!fir.type<_QFcomp_ref_in_atomic_capture1Tt1{c:i32}>>)
 ! CHECK:           %[[C:.*]] = hlfir.designate %[[X_DECL]]#0{"c"}   : (!fir.ref<!fir.type<_QFcomp_ref_in_atomic_capture1Tt1{c:i32}>>) -> !fir.ref<i32>
 ! CHECK:           acc.atomic.capture {
-! CHECK:             acc.atomic.read %[[V_DECL]]#1 = %[[C]] : !fir.ref<i32>, i32
+! CHECK:             acc.atomic.read %[[V_DECL]]#1 = %[[C]] : !fir.ref<i32>, !fir.ref<i32>, i32
 ! CHECK:             acc.atomic.update %[[C]] : !fir.ref<i32> {
 ! CHECK:             ^bb0(%[[VAL_5:.*]]: i32):
 ! CHECK:               %[[VAL_6:.*]] = arith.addi %[[VAL_5]], %{{.*}} : i32
@@ -215,5 +300,5 @@ end subroutine comp_ref_in_atomic_capture2
 ! CHECK:               %[[VAL_6:.*]] = arith.addi %[[VAL_5]], %{{.*}} : i32
 ! CHECK:               acc.yield %[[VAL_6]] : i32
 ! CHECK:             }
-! CHECK:             acc.atomic.read %[[V_DECL]]#1 = %[[C]] : !fir.ref<i32>, i32
+! CHECK:             acc.atomic.read %[[V_DECL]]#1 = %[[C]] : !fir.ref<i32>, !fir.ref<i32>, i32
 ! CHECK:           }
diff --git a/flang/test/Lower/OpenACC/acc-atomic-read.f90 b/flang/test/Lower/OpenACC/acc-atomic-read.f90
index c1a97a9e5f74f3..f2cbe6e45596a4 100644
--- a/flang/test/Lower/OpenACC/acc-atomic-read.f90
+++ b/flang/test/Lower/OpenACC/acc-atomic-read.f90
@@ -13,7 +13,7 @@ end program acc_atomic_test
 ! CHECK: %[[G_DECL:.*]]:2 = hlfir.declare %[[VAR_G]] {uniq_name = "_QFEg"} : (!fir.ref<f32>) -> (!fir.ref<f32>, !fir.ref<f32>)
 ! CHECK: %[[VAR_H:.*]] = fir.alloca f32 {bindc_name = "h", uniq_name = "_QFEh"}
 ! CHECK: %[[H_DECL:.*]]:2 = hlfir.declare %[[VAR_H]] {uniq_name = "_QFEh"} : (!fir.ref<f32>) -> (!fir.ref<f32>, !fir.ref<f32>)
-! CHECK: acc.atomic.read %[[G_DECL]]#1 = %[[H_DECL]]#1 : !fir.ref<f32>, f32
+! CHECK: acc.atomic.read %[[G_DECL]]#1 = %[[H_DECL]]#1 : !fir.ref<f32>, !fir.ref<f32>, f32
 ! CHECK: return
 ! CHECK: }
 
@@ -39,10 +39,10 @@ subroutine atomic_read_pointer()
 ! CHECK:   %[[BOX_ADDR_X:.*]] = fir.box_addr %[[LOAD_X]] : (!fir.box<!fir.ptr<i32>>) -> !fir.ptr<i32>
 ! CHECK:   %[[LOAD_Y:.*]] = fir.load %[[Y_DECL]]#0 : !fir.ref<!fir.box<!fir.ptr<i32>>>
 ! CHECK:   %[[BOX_ADDR_Y:.*]] = fir.box_addr %[[LOAD_Y]] : (!fir.box<!fir.ptr<i32>>) -> !fir.ptr<i32>
-! CHECK:   acc.atomic.read %[[BOX_ADDR_Y]] = %[[BOX_ADDR_X]] : !fir.ptr<i32>, i32
+! CHECK:   acc.atomic.read %[[BOX_ADDR_Y]] = %[[BOX_ADDR_X]] : !fir.ptr<i32>, !fir.ptr<i32>, i32
 ! CHECK: }
 
-subroutine atomic_read_with_convert()
+subroutine atomic_read_with_cast()
   integer(4) :: x
   integer(8) :: y
 
@@ -50,10 +50,9 @@ subroutine atomic_read_with_convert()
   y = x
 end
 
-! CHECK-LABEL: func.func @_QPatomic_read_with_convert() {
-! CHECK: %[[X:.*]] = fir.alloca i32 {bindc_name = "x", uniq_name = "_QFatomic_read_with_convertEx"}
-! CHECK: %[[X_DECL:.*]]:2 = hlfir.declare %[[X]] {uniq_name = "_QFatomic_read_with_convertEx"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
-! CHECK: %[[Y:.*]] = fir.alloca i64 {bindc_name = "y", uniq_name = "_QFatomic_read_with_convertEy"}
-! CHECK: %[[Y_DECL:.*]]:2 = hlfir.declare %[[Y]] {uniq_name = "_QFatomic_read_with_convertEy"} : (!fir.ref<i64>) -> (!fir.ref<i64>, !fir.ref<i64>)
-! CHECK: %[[CONV:.*]] = fir.convert %[[X_DECL]]#1 : (!fir.ref<i32>) -> !fir.ref<i64>
-! CHECK: acc.atomic.read %[[Y_DECL]]#1 = %[[CONV]] : !fir.ref<i64>, i32
+! CHECK-LABEL: func.func @_QPatomic_read_with_cast() {
+! CHECK: %[[X:.*]] = fir.alloca i32 {bindc_name = "x", uniq_name = "_QFatomic_read_with_castEx"}
+! CHECK: %[[X_DECL:.*]]:2 = hlfir.declare %[[X]] {uniq_name = "_QFatomic_read_with_castEx"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
+! CHECK: %[[Y:.*]] = fir.alloca i64 {bindc_name = "y", uniq_name = "_QFatomic_read_with_castEy"}
+! CHECK: %[[Y_DECL:.*]]:2 = hlfir.declare %[[Y]] {uniq_name = "_QFatomic_read_with_castEy"} : (!fir.ref<i64>) -> (!fir.ref<i64>, !fir.ref<i64>)
+! CHECK: acc.atomic.read %[[Y_DECL]]#1 = %[[X_DECL]]#1 : !fir.ref<i64>, !fir.ref<i32>, i32
diff --git a/flang/test/Lower/OpenACC/acc-atomic-update-array.f90 b/flang/test/Lower/OpenACC/acc-atomic-update-array.f90
index eeb7ea29940862..f89a9ab457d499 100644
--- a/flang/test/Lower/OpenACC/acc-atomic-update-array.f90
+++ b/flang/test/Lower/OpenACC/acc-atomic-update-array.f90
@@ -45,7 +45,7 @@ subroutine atomic_read_array1(r, n, x)
 ! CHECK: %[[DECL_X:.*]]:2 = hlfir.declare %[[ARG2]] dummy_scope %{{[0-9]+}} {uniq_name = "_QFatomic_read_array1Ex"} : (!fir.ref<f32>, !fir.dscope) -> (!fir.ref<f32>, !fir.ref<f32>)
 ! CHECK: %[[DECL_R:.*]]:2 = hlfir.declare %[[ARG0]](%{{.*}}) dummy_scope %{{[0-9]+}} {uniq_name = "_QFatomic_read_array1Er"} : (!fir.ref<!fir.array<?xf32>>, !fir.shape<1>, !fir.dscope) -> (!fir.box<!fir.array<?xf32>>, !fir.ref<!fir.array<?xf32>>)
 ! CHECK: %[[DES:.*]] = hlfir.designate %[[DECL_R]]#0 (%{{.*}})  : (!fir.box<!fir.array<?xf32>>, i64) -> !fir.ref<f32>
-! CHECK: acc.atomic.read %[[DECL_X]]#1 = %[[DES]] : !fir.ref<f32>, f32
+! CHECK: acc.atomic.read %[[DECL_X]]#1 = %[[DES]] : !fir.ref<f32>, !fir.ref<f32>, f32
 
 subroutine atomic_write_array1(r, n, x)
   implicit none
@@ -88,5 +88,5 @@ subroutine atomic_capture_array1(r, n, x, y)
 ! CHECK:     %[[ADD:.*]] = arith.addf %[[ARG]], %[[LOAD]] fastmath<contract> : f32
 ! CHECK:     acc.yield %[[ADD]] : f32
 ! CHECK:   }
-! CHECK:   acc.atomic.read %[[DECL_Y]]#1 = %[[R_I]] : !fir.ref<f32>, f32
+! CHECK:   acc.atomic.read %[[DECL_Y]]#1 = %[[R_I]] : !fir.ref<f32>, !fir.ref<f32>, f32
 ! CHECK: }
diff --git a/flang/test/Lower/OpenMP/atomic-capture.f90 b/flang/test/Lower/OpenMP/atomic-capture.f90
index af82e4b2a20eb2..679d22d3d7063e 100644
--- a/flang/test/Lower/OpenMP/atomic-capture.f90
+++ b/flang/test/Lower/OpenMP/atomic-capture.f90
@@ -22,7 +22,7 @@ program OmpAtomicCapture
 !CHECK: %[[TEMP:.*]] = arith.muli %[[VAL_Y_LOADED]], %[[ARG]] : i32
 !CHECK: omp.yield(%[[TEMP]] : i32)
 !CHECK: }
-!CHECK: omp.atomic.read %[[VAL_X_DECLARE]]#1 = %[[VAL_Y_DECLARE]]#1 : !fir.ref<i32>, i32
+!CHECK: omp.atomic.read %[[VAL_X_DECLARE]]#1 = %[[VAL_Y_DECLARE]]#1 : !fir.ref<i32>, !fir.ref<i32>, i32
 !CHECK: }
     !$omp atomic hint(omp_sync_hint_uncontended) capture
         y = x * y 
@@ -36,7 +36,7 @@ program OmpAtomicCapture
 !CHECK: %[[NO_REASSOC:.*]] = hlfir.no_reassoc %[[SUB]] : i32
 !CHECK: %[[ADD:.*]] = arith.addi  %[[VAL_20]], %[[NO_REASSOC]] : i32
 !CHECK: omp.atomic.capture hint(nonspeculative) memory_order(acquire) {
-!CHECK:   omp.atomic.read %[[VAL_X_DECLARE]]#1 = %[[VAL_Y_DECLARE]]#1 : !fir.ref<i32>, i32
+!CHECK:   omp.atomic.read %[[VAL_X_DECLARE]]#1 = %[[VAL_Y_DECLARE]]#1 : !fir.ref<i32>, !fir.ref<i32>, i32
 !CHECK:   omp.atomic.write %[[VAL_Y_DECLARE]]#1 = %[[ADD]] : !fir.ref<i32>, i32
 !CHECK: }
 !CHECK: return
@@ -88,7 +88,7 @@ subroutine pointers_in_atomic_capture()
 !CHECK: %[[TEMP:.*]] = arith.addi %[[ARG]], %[[VAL_B]] : i32
 !CHECK: omp.yield(%[[TEMP]] : i32)
 !CHECK: }
-!CHECK: omp.atomic.read %[[VAL_B_BOX_ADDR]] = %[[VAL_A_BOX_ADDR]] : !fir.ptr<i32>, i32
+!CHECK: omp.atomic.read %[[VAL_B_BOX_ADDR]] = %[[VAL_A_BOX_ADDR]] : !fir.ptr<i32>, !fir.ptr<i32>, i32
 !CHECK: }
 !CHECK: return
 !CHECK: }
diff --git a/flang/test/Lower/OpenMP/atomic-read.f90 b/flang/test/Lower/OpenMP/atomic-read.f90
index c3270dd6c1d670..e9bea42252faa3 100644
--- a/flang/test/Lower/OpenMP/atomic-read.f90
+++ b/flang/test/Lower/OpenMP/atomic-read.f90
@@ -25,12 +25,12 @@
 !CHECK:    %[[X_DECL:.*]]:2 = hlfir.declare %[[X_REF]] {uniq_name = "_QFEx"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
 !CHECK:    %[[Y_REF:.*]] = fir.alloca i32 {bindc_name = "y", uniq_name = "_QFEy"}
 !CHECK:    %[[Y_DECL:.*]]:2 = hlfir.declare %[[Y_REF]] {uniq_name = "_QFEy"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
-!CHECK:    omp.atomic.read %[[X_DECL]]#1 = %[[Y_DECL]]#1   hint(uncontended) memory_order(acquire) : !fir.ref<i32>, i32
-!CHECK:    omp.atomic.read %[[A_DECL]]#1 = %[[B_DECL]]#1   memory_order(relaxed) : !fir.ref<i32>, i32
-!CHECK:    omp.atomic.read %[[C_DECL]]#1 = %[[D_DECL]]#1   hint(contended) memory_order(seq_cst) : !fir.ref<!fir.logical<4>>, !fir.logical<4>
-!CHECK:    omp.atomic.read %[[E_DECL]]#1 = %[[F_DECL]]#1   hint(speculative) : !fir.ref<i32>, i32
-!CHECK:    omp.atomic.read %[[G_DECL]]#1 = %[[H_DECL]]#1   hint(nonspeculative) : !fir.ref<f32>, f32
-!CHECK:    omp.atomic.read %[[G_DECL]]#1 = %[[H_DECL]]#1   : !fir.ref<f32>, f32
+!CHECK:    omp.atomic.read %[[X_DECL]]#1 = %[[Y_DECL]]#1   hint(uncontended) memory_order(acquire) : !fir.ref<i32>, !fir.ref<i32>, i32
+!CHECK:    omp.atomic.read %[[A_DECL]]#1 = %[[B_DECL]]#1   memory_order(relaxed) : !fir.ref<i32>, !fir.ref<i32>, i32
+!CHECK:    omp.atomic.read %[[C_DECL]]#1 = %[[D_DECL]]#1   hint(contended) memory_order(seq_cst) : !fir.ref<!fir.logical<4>>, !fir.ref<!fir.logical<4>>, !fir.logical<4>
+!CHECK:    omp.atomic.read %[[E_DECL]]#1 = %[[F_DECL]]#1   hint(speculative) : !fir.ref<i32>, !fir.ref<i32>, i32
+!CHECK:    omp.atomic.read %[[G_DECL]]#1 = %[[H_DECL]]#1   hint(nonspeculative) : !fir.ref<f32>, !fir.ref<f32>, f32
+!CHECK:    omp.atomic.read %[[G_DECL]]#1 = %[[H_DECL]]#1   : !fir.ref<f32>, !fir.ref<f32>, f32
 
 program OmpAtomic
 
@@ -68,7 +68,7 @@ end program OmpAtomic
 !CHECK:    %[[X_POINTEE_ADDR:.*]] = fir.box_addr %[[X_ADDR]] : (!fir.box<!fir.ptr<i32>>) -> !fir.ptr<i32>
 !CHECK:    %[[Y_ADDR:.*]] = fir.load %[[Y_DECL]]#0 : !fir.ref<!fir.box<!fir.ptr<i32>>>
 !CHECK:    %[[Y_POINTEE_ADDR:.*]] = fir.box_addr %[[Y_ADDR]] : (!fir.box<!fir.ptr<i32>>) -> !fir.ptr<i32>
-!CHECK:    omp.atomic.read %[[Y_POINTEE_ADDR]] = %[[X_POINTEE_ADDR]]   : !fir.ptr<i32>, i32
+!CHECK:    omp.atomic.read %[[Y_POINTEE_ADDR]] = %[[X_POINTEE_ADDR]]   : !fir.ptr<i32>, !fir.ptr<i32>, i32
 !CHECK:    %[[Y_ADDR:.*]] = fir.load %[[Y_DECL]]#0 : !fir.ref<!fir.box<!fir.ptr<i32>>>
 !CHECK:    %[[Y_POINTEE_ADDR:.*]] = fir.box_addr %[[Y_ADDR]] : (!fir.box<!fir.ptr<i32>>) -> !fir.ptr<i32>
 !CHECK:    %[[Y_POINTEE_VAL:.*]] = fir.load %[[Y_POINTEE_ADDR]] : !fir.ptr<i32>
diff --git a/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td b/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td
index d9f38259c0ace0..6115613a3dbaae 100644
--- a/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td
+++ b/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td
@@ -1948,8 +1948,7 @@ def OpenACC_YieldOp : OpenACC_Op<"yield", [Pure, ReturnLike, Terminator,
 // 2.12 atomic construct
 //===----------------------------------------------------------------------===//
 
-def AtomicReadOp : OpenACC_Op<"atomic.read", [AllTypesMatch<["x", "v"]>,
-                                              AtomicReadOpInterface]> {
+def AtomicReadOp : OpenACC_Op<"atomic.read", [AtomicReadOpInterface]> {
 
   let summary = "performs an atomic read";
 
@@ -1965,7 +1964,7 @@ def AtomicReadOp : OpenACC_Op<"atomic.read", [AllTypesMatch<["x", "v"]>,
                        TypeAttr:$element_type);
   let assemblyFormat = [{
     $v `=` $x
-    `:` type($x) `,` $element_type attr-dict
+    `:` type($v) `,` type($x) `,` $element_type attr-dict
   }];
   let hasVerifier = 1;
 }
diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
index 626539cb7bde42..5fd8184fe0e0f7 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
@@ -1286,7 +1286,7 @@ def TaskwaitOp : OpenMP_Op<"taskwait", clauses = [
 // two-step process.
 
 def AtomicReadOp : OpenMP_Op<"atomic.read", traits = [
-    AllTypesMatch<["x", "v"]>, AtomicReadOpInterface
+    AtomicReadOpInterface
   ], clauses = [
     OpenMP_HintClause, OpenMP_MemoryOrderClause
   ]> {
@@ -1304,7 +1304,8 @@ def AtomicReadOp : OpenMP_Op<"atomic.read", traits = [
 
   // Override clause-based assemblyFormat.
   let assemblyFormat = "$v `=` $x" # clausesReqAssemblyFormat # " oilist(" #
-    clausesOptAssemblyFormat # ") `:` type($x) `,` $element_type attr-dict";
+    clausesOptAssemblyFormat #
+    ") `:` type($v) `,` type($x) `,` $element_type attr-dict";
 
   let extraClassDeclaration = [{
     /// The number of variable operands.


