[flang-commits] [flang] [flang][openacc] Convert rhs expr to the lhs type on atomic read/write (PR #70377)

Valentin Clement バレンタイン クレメン via flang-commits flang-commits at lists.llvm.org
Fri Oct 27 09:59:29 PDT 2023


https://github.com/clementval updated https://github.com/llvm/llvm-project/pull/70377

>From 18e8b4f0bc91068155e885af81eec5469ec4033f Mon Sep 17 00:00:00 2001
From: Valentin Clement <clementval at gmail.com>
Date: Thu, 26 Oct 2023 11:39:47 -0700
Subject: [PATCH 1/2] [flang][openacc] Convert rhs expr to the lhs type on
 atomic read/write

---
 flang/lib/Lower/DirectivesCommon.h            | 40 +++++++++++++++++++
 .../test/Lower/OpenACC/acc-atomic-capture.f90 | 25 ++++++++++++
 flang/test/Lower/OpenACC/acc-atomic-read.f90  | 14 +++++++
 3 files changed, 79 insertions(+)

diff --git a/flang/lib/Lower/DirectivesCommon.h b/flang/lib/Lower/DirectivesCommon.h
index 656fd36099b51ff..496f695e7e01f29 100644
--- a/flang/lib/Lower/DirectivesCommon.h
+++ b/flang/lib/Lower/DirectivesCommon.h
@@ -327,6 +327,36 @@ void genOmpAccAtomicWrite(Fortran::lower::AbstractConverter &converter,
                                 rightHandClauseList);
 }
 
+template <typename PtrLikeTy>
+mlir::Value convertRhs(fir::FirOpBuilder &builder, mlir::Location loc,
+                       PtrLikeTy ptrLikeType, mlir::Value rhs,
+                       mlir::Operation *captureOp) {
+  if (ptrLikeType.getElementType() != rhs.getType()) {
+    auto crtPos = builder.saveInsertionPoint();
+    builder.setInsertionPoint(captureOp);
+    mlir::Value convertedRhs =
+        builder.create<fir::ConvertOp>(loc, ptrLikeType.getElementType(), rhs);
+    builder.restoreInsertionPoint(crtPos);
+    return convertedRhs;
+  }
+  return rhs;
+}
+
+static mlir::Value addConversionIfNeeded(fir::FirOpBuilder &builder,
+                                         mlir::Location loc, mlir::Value lhs,
+                                         mlir::Value rhs,
+                                         mlir::Operation *captureOp = nullptr) {
+  if (auto ptrLikeType =
+          mlir::dyn_cast_or_null<mlir::acc::PointerLikeType>(lhs.getType()))
+    return convertRhs<mlir::acc::PointerLikeType>(builder, loc, ptrLikeType,
+                                                  rhs, captureOp);
+  if (auto ptrLikeType =
+          mlir::dyn_cast_or_null<mlir::omp::PointerLikeType>(lhs.getType()))
+    return convertRhs<mlir::omp::PointerLikeType>(builder, loc, ptrLikeType,
+                                                  rhs, captureOp);
+  return rhs;
+}
+
 /// Processes an atomic construct with read clause.
 template <typename AtomicT, typename AtomicListT>
 void genOmpAccAtomicRead(Fortran::lower::AbstractConverter &converter,
@@ -357,6 +387,13 @@ void genOmpAccAtomicRead(Fortran::lower::AbstractConverter &converter,
       fir::getBase(converter.genExprAddr(fromExpr, stmtCtx));
   mlir::Value toAddress = fir::getBase(converter.genExprAddr(
       *Fortran::semantics::GetExpr(assignmentStmtVariable), stmtCtx));
+
+  mlir::Location loc = converter.getCurrentLocation();
+  fir::FirOpBuilder &builder = converter.getFirOpBuilder();
+
+  if (fromAddress.getType() != toAddress.getType())
+    fromAddress =
+        builder.create<fir::ConvertOp>(loc, toAddress.getType(), fromAddress);
   genOmpAccAtomicCaptureStatement(converter, fromAddress, toAddress,
                                   leftHandClauseList, rightHandClauseList,
                                   elementType);
@@ -517,6 +554,9 @@ void genOmpAccAtomicCapture(Fortran::lower::AbstractConverter &converter,
           converter, stmt1RHSArg, stmt1LHSArg,
           /*leftHandClauseList=*/nullptr,
           /*rightHandClauseList=*/nullptr, elementType);
+      stmt2RHSArg =
+          addConversionIfNeeded(firOpBuilder, currentLocation, stmt1RHSArg,
+                                stmt2RHSArg, atomicCaptureOp);
       genOmpAccAtomicWriteStatement<AtomicListT>(
           converter, stmt1RHSArg, stmt2RHSArg,
           /*leftHandClauseList=*/nullptr,
diff --git a/flang/test/Lower/OpenACC/acc-atomic-capture.f90 b/flang/test/Lower/OpenACC/acc-atomic-capture.f90
index 382991cf7221ba7..59c16cf61acf718 100644
--- a/flang/test/Lower/OpenACC/acc-atomic-capture.f90
+++ b/flang/test/Lower/OpenACC/acc-atomic-capture.f90
@@ -97,3 +97,28 @@ subroutine pointers_in_atomic_capture()
         b = a
     !$acc end atomic
 end subroutine
+
+
+subroutine capture_with_convert_f32_to_i32()
+  implicit none
+  integer :: k, v, i
+
+  k = 1
+  v = 0
+
+  !$acc atomic capture
+  v = k
+  k = (i + 1) * 3.14
+  !$acc end atomic
+end subroutine
+
+! CHECK-LABEL: func.func @_QPcapture_with_convert_f32_to_i32()
+! CHECK: %[[K:.*]] = fir.alloca i32 {bindc_name = "k", uniq_name = "_QFcapture_with_convert_f32_to_i32Ek"}
+! CHECK: %[[V:.*]] = fir.alloca i32 {bindc_name = "v", uniq_name = "_QFcapture_with_convert_f32_to_i32Ev"}
+! CHECK: %[[CST:.*]] = arith.constant 3.140000e+00 : f32
+! CHECK: %[[MUL:.*]] = arith.mulf %{{.*}}, %[[CST]] fastmath<contract> : f32
+! CHECK: %[[CONV:.*]] = fir.convert %[[MUL]] : (f32) -> i32
+! CHECK: acc.atomic.capture {
+! CHECK:   acc.atomic.read %[[V]] = %[[K]] : !fir.ref<i32>, i32
+! CHECK:   acc.atomic.write %[[K]] = %[[CONV]] : !fir.ref<i32>, i32
+! CHECK: }
diff --git a/flang/test/Lower/OpenACC/acc-atomic-read.f90 b/flang/test/Lower/OpenACC/acc-atomic-read.f90
index 28f0ce44e6f413d..3a718576124c3bb 100644
--- a/flang/test/Lower/OpenACC/acc-atomic-read.f90
+++ b/flang/test/Lower/OpenACC/acc-atomic-read.f90
@@ -46,3 +46,17 @@ subroutine atomic_read_pointer()
 
   x = y
 end
+
+subroutine atomic_read_with_convert()
+  integer(4) :: x
+  integer(8) :: y
+
+  !$acc atomic read
+  y = x
+end
+
+! CHECK-LABEL: func.func @_QPatomic_read_with_convert() {
+! CHECK: %[[X:.*]] = fir.alloca i32 {bindc_name = "x", uniq_name = "_QFatomic_read_with_convertEx"}
+! CHECK: %[[Y:.*]] = fir.alloca i64 {bindc_name = "y", uniq_name = "_QFatomic_read_with_convertEy"}
+! CHECK: %[[CONV:.*]] = fir.convert %[[X]] : (!fir.ref<i32>) -> !fir.ref<i64>
+! CHECK: acc.atomic.read %[[Y]] = %[[CONV]] : !fir.ref<i64>, i32

>From 9f337965b777f7dec4107cf519e5970bff6aad6a Mon Sep 17 00:00:00 2001
From: Valentin Clement <clementval at gmail.com>
Date: Fri, 27 Oct 2023 09:59:02 -0700
Subject: [PATCH 2/2] Use typed assignment so the conversion is done for us

---
 flang/lib/Lower/DirectivesCommon.h | 64 +++++-------------------------
 1 file changed, 10 insertions(+), 54 deletions(-)

diff --git a/flang/lib/Lower/DirectivesCommon.h b/flang/lib/Lower/DirectivesCommon.h
index 496f695e7e01f29..1b231ee1b891baf 100644
--- a/flang/lib/Lower/DirectivesCommon.h
+++ b/flang/lib/Lower/DirectivesCommon.h
@@ -327,36 +327,6 @@ void genOmpAccAtomicWrite(Fortran::lower::AbstractConverter &converter,
                                 rightHandClauseList);
 }
 
-template <typename PtrLikeTy>
-mlir::Value convertRhs(fir::FirOpBuilder &builder, mlir::Location loc,
-                       PtrLikeTy ptrLikeType, mlir::Value rhs,
-                       mlir::Operation *captureOp) {
-  if (ptrLikeType.getElementType() != rhs.getType()) {
-    auto crtPos = builder.saveInsertionPoint();
-    builder.setInsertionPoint(captureOp);
-    mlir::Value convertedRhs =
-        builder.create<fir::ConvertOp>(loc, ptrLikeType.getElementType(), rhs);
-    builder.restoreInsertionPoint(crtPos);
-    return convertedRhs;
-  }
-  return rhs;
-}
-
-static mlir::Value addConversionIfNeeded(fir::FirOpBuilder &builder,
-                                         mlir::Location loc, mlir::Value lhs,
-                                         mlir::Value rhs,
-                                         mlir::Operation *captureOp = nullptr) {
-  if (auto ptrLikeType =
-          mlir::dyn_cast_or_null<mlir::acc::PointerLikeType>(lhs.getType()))
-    return convertRhs<mlir::acc::PointerLikeType>(builder, loc, ptrLikeType,
-                                                  rhs, captureOp);
-  if (auto ptrLikeType =
-          mlir::dyn_cast_or_null<mlir::omp::PointerLikeType>(lhs.getType()))
-    return convertRhs<mlir::omp::PointerLikeType>(builder, loc, ptrLikeType,
-                                                  rhs, captureOp);
-  return rhs;
-}
-
 /// Processes an atomic construct with read clause.
 template <typename AtomicT, typename AtomicListT>
 void genOmpAccAtomicRead(Fortran::lower::AbstractConverter &converter,
@@ -387,10 +357,8 @@ void genOmpAccAtomicRead(Fortran::lower::AbstractConverter &converter,
       fir::getBase(converter.genExprAddr(fromExpr, stmtCtx));
   mlir::Value toAddress = fir::getBase(converter.genExprAddr(
       *Fortran::semantics::GetExpr(assignmentStmtVariable), stmtCtx));
-
   mlir::Location loc = converter.getCurrentLocation();
   fir::FirOpBuilder &builder = converter.getFirOpBuilder();
-
   if (fromAddress.getType() != toAddress.getType())
     fromAddress =
         builder.create<fir::ConvertOp>(loc, toAddress.getType(), fromAddress);
@@ -464,10 +432,12 @@ void genOmpAccAtomicCapture(Fortran::lower::AbstractConverter &converter,
 
   const Fortran::parser::AssignmentStmt &stmt1 =
       std::get<typename AtomicT::Stmt1>(atomicCapture.t).v.statement;
+  const Fortran::evaluate::Assignment &assign1 = *stmt1.typedAssignment->v;
   const auto &stmt1Var{std::get<Fortran::parser::Variable>(stmt1.t)};
   const auto &stmt1Expr{std::get<Fortran::parser::Expr>(stmt1.t)};
   const Fortran::parser::AssignmentStmt &stmt2 =
       std::get<typename AtomicT::Stmt2>(atomicCapture.t).v.statement;
+  const Fortran::evaluate::Assignment &assign2 = *stmt2.typedAssignment->v;
   const auto &stmt2Var{std::get<Fortran::parser::Variable>(stmt2.t)};
   const auto &stmt2Expr{std::get<Fortran::parser::Expr>(stmt2.t)};
 
@@ -479,36 +449,25 @@ void genOmpAccAtomicCapture(Fortran::lower::AbstractConverter &converter,
   mlir::Value stmt1LHSArg, stmt1RHSArg, stmt2LHSArg, stmt2RHSArg;
   mlir::Type elementType;
   // LHS evaluations are common to all combinations of `atomic.capture`
-  stmt1LHSArg = fir::getBase(
-      converter.genExprAddr(*Fortran::semantics::GetExpr(stmt1Var), stmtCtx));
-  stmt2LHSArg = fir::getBase(
-      converter.genExprAddr(*Fortran::semantics::GetExpr(stmt2Var), stmtCtx));
+  stmt1LHSArg = fir::getBase(converter.genExprAddr(assign1.lhs, stmtCtx));
+  stmt2LHSArg = fir::getBase(converter.genExprAddr(assign2.lhs, stmtCtx));
 
   // Operation specific RHS evaluations
   if (checkForSingleVariableOnRHS(stmt1)) {
     // Atomic capture construct is of the form [capture-stmt, update-stmt] or
     // of the form [capture-stmt, write-stmt]
-    stmt1RHSArg = fir::getBase(converter.genExprAddr(
-        *Fortran::semantics::GetExpr(stmt1Expr), stmtCtx));
-    stmt2RHSArg = fir::getBase(converter.genExprValue(
-        *Fortran::semantics::GetExpr(stmt2Expr), stmtCtx));
-
+    stmt1RHSArg = fir::getBase(converter.genExprAddr(assign1.rhs, stmtCtx));
+    stmt2RHSArg = fir::getBase(converter.genExprValue(assign2.rhs, stmtCtx));
   } else {
     // Atomic capture construct is of the form [update-stmt, capture-stmt]
-    stmt1RHSArg = fir::getBase(converter.genExprValue(
-        *Fortran::semantics::GetExpr(stmt1Expr), stmtCtx));
-    stmt2RHSArg = fir::getBase(converter.genExprAddr(
-        *Fortran::semantics::GetExpr(stmt2Expr), stmtCtx));
+    stmt1RHSArg = fir::getBase(converter.genExprValue(assign1.rhs, stmtCtx));
+    stmt2RHSArg = fir::getBase(converter.genExprAddr(assign2.lhs, stmtCtx));
   }
   // Type information used in generation of `atomic.update` operation
   mlir::Type stmt1VarType =
-      fir::getBase(converter.genExprValue(
-                       *Fortran::semantics::GetExpr(stmt1Var), stmtCtx))
-          .getType();
+      fir::getBase(converter.genExprValue(assign1.lhs, stmtCtx)).getType();
   mlir::Type stmt2VarType =
-      fir::getBase(converter.genExprValue(
-                       *Fortran::semantics::GetExpr(stmt2Var), stmtCtx))
-          .getType();
+      fir::getBase(converter.genExprValue(assign2.lhs, stmtCtx)).getType();
 
   mlir::Operation *atomicCaptureOp = nullptr;
   if constexpr (std::is_same<AtomicListT,
@@ -554,9 +513,6 @@ void genOmpAccAtomicCapture(Fortran::lower::AbstractConverter &converter,
           converter, stmt1RHSArg, stmt1LHSArg,
           /*leftHandClauseList=*/nullptr,
           /*rightHandClauseList=*/nullptr, elementType);
-      stmt2RHSArg =
-          addConversionIfNeeded(firOpBuilder, currentLocation, stmt1RHSArg,
-                                stmt2RHSArg, atomicCaptureOp);
       genOmpAccAtomicWriteStatement<AtomicListT>(
           converter, stmt1RHSArg, stmt2RHSArg,
           /*leftHandClauseList=*/nullptr,



More information about the flang-commits mailing list