[flang] [llvm] [flang][llvm][OpenMP][OpenACC] Add implicit casts to omp.atomic and acc.atomic (PR #131603)

via llvm-commits llvm-commits at lists.llvm.org
Mon Mar 17 04:41:31 PDT 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-flang-fir-hlfir

Author: None (NimishMishra)

<details>
<summary>Changes</summary>

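Currently, the implicit type conversions required when the operands of a Fortran atomic construct have different types are performed by the OMPIRBuilder. This patch moves that responsibility into flang lowering (`flang/include/flang/Lower/DirectivesCommon.h`), for both `omp.atomic` and `acc.atomic`. As a result, the OMPIRBuilder no longer receives `atomic.read`/`atomic.capture` operands whose types differ.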

---
Full diff: https://github.com/llvm/llvm-project/pull/131603.diff


5 Files Affected:

- (modified) flang/include/flang/Lower/DirectivesCommon.h (+84-6) 
- (modified) flang/test/Lower/OpenACC/acc-atomic-capture.f90 (+7-4) 
- (modified) flang/test/Lower/OpenACC/acc-atomic-read.f90 (+5-1) 
- (added) flang/test/Lower/OpenMP/atomic-implicit-cast.f90 (+121) 
- (modified) llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp (-31) 


``````````diff
diff --git a/flang/include/flang/Lower/DirectivesCommon.h b/flang/include/flang/Lower/DirectivesCommon.h
index 6e24343cebd3a..6c20df1897fd4 100644
--- a/flang/include/flang/Lower/DirectivesCommon.h
+++ b/flang/include/flang/Lower/DirectivesCommon.h
@@ -29,6 +29,7 @@
 #include "flang/Lower/PFTBuilder.h"
 #include "flang/Lower/StatementContext.h"
 #include "flang/Lower/Support/Utils.h"
+#include "flang/Optimizer/Builder/Complex.h"
 #include "flang/Optimizer/Builder/DirectivesCommon.h"
 #include "flang/Optimizer/Builder/HLFIRTools.h"
 #include "flang/Optimizer/Dialect/FIRType.h"
@@ -103,6 +104,61 @@ static void processOmpAtomicTODO(mlir::Type elementType,
   }
 }
 
+/// Emits an implicit cast for an atomic statement whose operands differ in
+/// type: loads \p fromAddress, converts the value to the element type of
+/// \p toAddress (complex -> non-complex keeps the real part; non-complex ->
+/// complex uses a zero imaginary part), stores it into a new temporary, and
+/// redirects \p fromAddress and \p elementType to it. No-op if types match.
+/// NOTE(review): the fir.load below is emitted outside the atomic op, so when
+/// \p fromAddress is the shared variable that read is not atomic, and a
+/// caller reusing \p fromAddress as an atomic write target writes the temp.
+/// Converting after the atomic read (into a source-typed temporary) avoids it.
+static void emitImplicitCast(Fortran::lower::AbstractConverter &converter,
+                             mlir::Location loc, mlir::Value &fromAddress,
+                             mlir::Value &toAddress, mlir::Type &elementType) {
+  if (fromAddress.getType() == toAddress.getType())
+    return;
+  fir::FirOpBuilder &builder = converter.getFirOpBuilder();
+  mlir::Type toType = fir::unwrapRefType(toAddress.getType());
+  mlir::Type fromType = fir::unwrapRefType(fromAddress.getType());
+  mlir::Value alloca = builder.create<fir::AllocaOp>(loc, toType);
+  mlir::Value loadedVal = builder.create<fir::LoadOp>(loc, fromAddress);
+  // Position attribute selecting the real (0) or imaginary (1) component.
+  auto part = [&](int64_t idx) {
+    return builder.getArrayAttr(
+        builder.getIntegerAttr(builder.getIndexType(), idx));
+  };
+  auto extract = [&](int64_t idx) -> mlir::Value {
+    return builder.create<fir::ExtractValueOp>(
+        loc, mlir::cast<mlir::ComplexType>(fromType).getElementType(),
+        loadedVal, part(idx));
+  };
+  mlir::Value val;
+  if (!fir::isa_complex(toType)) {
+    if (fir::isa_complex(fromType)) // Only the real part is kept.
+      loadedVal = extract(0);
+    val = builder.create<fir::ConvertOp>(loc, toType, loadedVal);
+  } else {
+    mlir::Type toEleTy = mlir::cast<mlir::ComplexType>(toType).getElementType();
+    mlir::Value re, im;
+    if (fir::isa_complex(fromType)) {
+      re = extract(0);
+      im = extract(1);
+      re = builder.create<fir::ConvertOp>(loc, toEleTy, re);
+      im = builder.create<fir::ConvertOp>(loc, toEleTy, im);
+    } else { // Non-complex -> complex: the imaginary part is zero.
+      re = builder.create<fir::ConvertOp>(loc, toEleTy, loadedVal);
+      im = builder.createRealZeroConstant(loc, toEleTy);
+    }
+    mlir::Value undef = builder.create<fir::UndefOp>(loc, toType);
+    val = builder.create<fir::InsertValueOp>(loc, toType, undef, re, part(0));
+    val = builder.create<fir::InsertValueOp>(loc, toType, val, im, part(1));
+  }
+  builder.create<fir::StoreOp>(loc, val, alloca);
+  fromAddress = alloca;
+  elementType = toType;
+}
+
 /// Used to generate atomic.read operation which is created in existing
 /// location set by builder.
 template <typename AtomicListT>
@@ -386,6 +442,7 @@ void genOmpAccAtomicRead(Fortran::lower::AbstractConverter &converter,
       fir::getBase(converter.genExprAddr(fromExpr, stmtCtx));
   mlir::Value toAddress = fir::getBase(converter.genExprAddr(
       *Fortran::semantics::GetExpr(assignmentStmtVariable), stmtCtx));
+  emitImplicitCast(converter, loc, fromAddress, toAddress, elementType);
   genOmpAccAtomicCaptureStatement(converter, fromAddress, toAddress,
                                   leftHandClauseList, rightHandClauseList,
                                   elementType, loc);
@@ -481,6 +538,30 @@ void genOmpAccAtomicCapture(Fortran::lower::AbstractConverter &converter,
   mlir::Type stmt2VarType =
       fir::getBase(converter.genExprValue(assign2.lhs, stmtCtx)).getType();
 
+  // Checks helpful in constructing the `atomic.capture` region
+  bool hasSingleVariable =
+      Fortran::semantics::checkForSingleVariableOnRHS(stmt1);
+  bool hasSymMatch = Fortran::semantics::checkForSymbolMatch(stmt2);
+
+  // Implicit casts
+  mlir::Type captureStmtElemTy;
+  if (hasSingleVariable) {
+    if (hasSymMatch) {
+      // Atomic capture construct is of the form [capture-stmt, update-stmt]
+      // FIXME: Emit an implicit cast if there is a type mismatch
+    } else {
+      // Atomic capture construct is of the form [capture-stmt, write-stmt]
+      const Fortran::semantics::SomeExpr &fromExpr =
+          *Fortran::semantics::GetExpr(stmt1Expr);
+      captureStmtElemTy = converter.genType(fromExpr);
+      emitImplicitCast(converter, loc, stmt2LHSArg, stmt1LHSArg,
+                       captureStmtElemTy);
+    }
+  } else {
+    // Atomic capture construct is of the form [update-stmt, capture-stmt]
+    // FIXME: Emit an implicit cast if there is a type mismatch
+  }
+
   mlir::Operation *atomicCaptureOp = nullptr;
   if constexpr (std::is_same<AtomicListT,
                              Fortran::parser::OmpAtomicClauseList>()) {
@@ -501,8 +582,8 @@ void genOmpAccAtomicCapture(Fortran::lower::AbstractConverter &converter,
   firOpBuilder.createBlock(&(atomicCaptureOp->getRegion(0)));
   mlir::Block &block = atomicCaptureOp->getRegion(0).back();
   firOpBuilder.setInsertionPointToStart(&block);
-  if (Fortran::semantics::checkForSingleVariableOnRHS(stmt1)) {
-    if (Fortran::semantics::checkForSymbolMatch(stmt2)) {
+  if (hasSingleVariable) {
+    if (hasSymMatch) {
       // Atomic capture construct is of the form [capture-stmt, update-stmt]
       const Fortran::semantics::SomeExpr &fromExpr =
           *Fortran::semantics::GetExpr(stmt1Expr);
@@ -521,13 +602,10 @@ void genOmpAccAtomicCapture(Fortran::lower::AbstractConverter &converter,
       mlir::Value stmt2RHSArg =
           fir::getBase(converter.genExprValue(assign2.rhs, stmtCtx));
       firOpBuilder.setInsertionPointToStart(&block);
-      const Fortran::semantics::SomeExpr &fromExpr =
-          *Fortran::semantics::GetExpr(stmt1Expr);
-      mlir::Type elementType = converter.genType(fromExpr);
       genOmpAccAtomicCaptureStatement<AtomicListT>(
           converter, stmt2LHSArg, stmt1LHSArg,
           /*leftHandClauseList=*/nullptr,
-          /*rightHandClauseList=*/nullptr, elementType, loc);
+          /*rightHandClauseList=*/nullptr, captureStmtElemTy, loc);
       genOmpAccAtomicWriteStatement<AtomicListT>(
           converter, stmt2LHSArg, stmt2RHSArg,
           /*leftHandClauseList=*/nullptr,
diff --git a/flang/test/Lower/OpenACC/acc-atomic-capture.f90 b/flang/test/Lower/OpenACC/acc-atomic-capture.f90
index 797d322ca7ef1..c074a3e8d804e 100644
--- a/flang/test/Lower/OpenACC/acc-atomic-capture.f90
+++ b/flang/test/Lower/OpenACC/acc-atomic-capture.f90
@@ -142,11 +142,14 @@ end subroutine capture_with_convert_i32_to_f64
 ! CHECK: hlfir.assign %[[CST]] to %[[X_DECL]]#0 : f64, !fir.ref<f64>
 ! CHECK: %c0_i32 = arith.constant 0 : i32
 ! CHECK: hlfir.assign %c0_i32 to %[[V_DECL]]#0 : i32, !fir.ref<i32>
-! CHECK: %[[LOAD:.*]] = fir.load %[[V_DECL]]#0 : !fir.ref<i32>
-! CHECK: %[[CONV:.*]] = fir.convert %[[LOAD]] : (i32) -> f64
+! CHECK: %[[ALLOCA:.*]] = fir.alloca i32
+! CHECK: %[[LOAD:.*]] = fir.load %[[X_DECL]]#1 : !fir.ref<f64>
+! CHECK: %[[CVT:.*]] = fir.convert %[[LOAD]] : (f64) -> i32
+! CHECK: fir.store %[[CVT]] to %[[ALLOCA]] : !fir.ref<i32>
+! CHECK: %[[EXPR_CVT:.*]] = fir.convert {{.*}} : (f64) -> i32
 ! CHECK: acc.atomic.capture {
-! CHECK:   acc.atomic.read %[[V_DECL]]#1 = %[[X_DECL]]#1 : !fir.ref<i32>, !fir.ref<f64>, f64
-! CHECK:   acc.atomic.write %[[X_DECL]]#1 = %[[CONV]] : !fir.ref<f64>, f64
+! CHECK:   acc.atomic.read %[[V_DECL]]#1 = %[[ALLOCA]] : !fir.ref<i32>, !fir.ref<i32>, i32
+! CHECK:   acc.atomic.write %[[ALLOCA]] = %[[EXPR_CVT]] : !fir.ref<i32>, i32 
 ! CHECK: }
 
 subroutine capture_with_convert_f64_to_i32()
diff --git a/flang/test/Lower/OpenACC/acc-atomic-read.f90 b/flang/test/Lower/OpenACC/acc-atomic-read.f90
index f2cbe6e45596a..d7f33b9e0a9ef 100644
--- a/flang/test/Lower/OpenACC/acc-atomic-read.f90
+++ b/flang/test/Lower/OpenACC/acc-atomic-read.f90
@@ -55,4 +55,8 @@ subroutine atomic_read_with_cast()
 ! CHECK: %[[X_DECL:.*]]:2 = hlfir.declare %[[X]] {uniq_name = "_QFatomic_read_with_castEx"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
 ! CHECK: %[[Y:.*]] = fir.alloca i64 {bindc_name = "y", uniq_name = "_QFatomic_read_with_castEy"}
 ! CHECK: %[[Y_DECL:.*]]:2 = hlfir.declare %[[Y]] {uniq_name = "_QFatomic_read_with_castEy"} : (!fir.ref<i64>) -> (!fir.ref<i64>, !fir.ref<i64>)
-! CHECK: acc.atomic.read %[[Y_DECL]]#1 = %[[X_DECL]]#1 : !fir.ref<i64>, !fir.ref<i32>, i32
+! CHECK: %[[ALLOCA:.*]] = fir.alloca i64
+! CHECK: %[[LOAD:.*]] = fir.load %[[X_DECL]]#1 : !fir.ref<i32>
+! CHECK: %[[CVT:.*]] = fir.convert %[[LOAD]] : (i32) -> i64
+! CHECK: fir.store %[[CVT]] to %[[ALLOCA]] : !fir.ref<i64>
+! CHECK: acc.atomic.read %[[Y_DECL]]#1 = %[[ALLOCA]] : !fir.ref<i64>, !fir.ref<i64>, i64
diff --git a/flang/test/Lower/OpenMP/atomic-implicit-cast.f90 b/flang/test/Lower/OpenMP/atomic-implicit-cast.f90
new file mode 100644
index 0000000000000..2267e3b62ec57
--- /dev/null
+++ b/flang/test/Lower/OpenMP/atomic-implicit-cast.f90
@@ -0,0 +1,121 @@
+! REQUIRES: openmp_runtime
+
+! RUN: %flang_fc1 -emit-hlfir %openmp_flags %s -o - | FileCheck %s
+
+! CHECK: func.func @_QPatomic_implicit_cast_read() {
+subroutine atomic_implicit_cast_read  ! atomic read where v and x differ in type
+! CHECK: %[[VAL_M:.*]] = fir.alloca complex<f64> {bindc_name = "m", uniq_name = "_QFatomic_implicit_cast_readEm"}
+! CHECK: %[[VAL_M_DECLARE:.*]]:2 = hlfir.declare %[[VAL_M]] {uniq_name = "_QFatomic_implicit_cast_readEm"} : (!fir.ref<complex<f64>>) -> (!fir.ref<complex<f64>>, !fir.ref<complex<f64>>)
+! CHECK: %[[VAL_W:.*]] = fir.alloca complex<f32> {bindc_name = "w", uniq_name = "_QFatomic_implicit_cast_readEw"}
+! CHECK: %[[VAL_W_DECLARE:.*]]:2 = hlfir.declare %[[VAL_W]] {uniq_name = "_QFatomic_implicit_cast_readEw"} : (!fir.ref<complex<f32>>) -> (!fir.ref<complex<f32>>, !fir.ref<complex<f32>>)
+! CHECK: %[[VAL_X:.*]] = fir.alloca i32 {bindc_name = "x", uniq_name = "_QFatomic_implicit_cast_readEx"}
+! CHECK: %[[VAL_X_DECLARE:.*]]:2 = hlfir.declare %[[VAL_X]] {uniq_name = "_QFatomic_implicit_cast_readEx"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
+! CHECK: %[[VAL_Y:.*]] = fir.alloca f32 {bindc_name = "y", uniq_name = "_QFatomic_implicit_cast_readEy"}
+! CHECK: %[[VAL_Y_DECLARE:.*]]:2 = hlfir.declare %[[VAL_Y]] {uniq_name = "_QFatomic_implicit_cast_readEy"} : (!fir.ref<f32>) -> (!fir.ref<f32>, !fir.ref<f32>)
+! CHECK: %[[VAL_Z:.*]] = fir.alloca f64 {bindc_name = "z", uniq_name = "_QFatomic_implicit_cast_readEz"}
+! CHECK: %[[VAL_Z_DECLARE:.*]]:2 = hlfir.declare %[[VAL_Z]] {uniq_name = "_QFatomic_implicit_cast_readEz"} : (!fir.ref<f64>) -> (!fir.ref<f64>, !fir.ref<f64>)
+    integer :: x
+    real    :: y
+    double precision :: z
+    complex :: w
+    complex(8) :: m
+
+    ! Atomic read. NOTE(review): x is loaded/converted before omp.atomic.read, outside the atomic op
+
+! CHECK: %[[ALLOCA:.*]] = fir.alloca i32
+! CHECK: %[[LOAD:.*]] = fir.load %[[VAL_Y_DECLARE]]#1 : !fir.ref<f32>
+! CHECK: %[[CVT:.*]] = fir.convert %[[LOAD]] : (f32) -> i32
+! CHECK: fir.store %[[CVT]] to %[[ALLOCA]] : !fir.ref<i32>
+! CHECK: omp.atomic.read %[[VAL_X_DECLARE]]#1 = %[[ALLOCA]] : !fir.ref<i32>, !fir.ref<i32>, i32
+    !$omp atomic read
+       x = y  ! real(4) -> integer(4)
+
+! CHECK: %[[ALLOCA:.*]] = fir.alloca f64
+! CHECK: %[[LOAD:.*]] = fir.load %[[VAL_X_DECLARE]]#1 : !fir.ref<i32>
+! CHECK: %[[CVT:.*]] = fir.convert %[[LOAD]] : (i32) -> f64
+! CHECK: fir.store %[[CVT]] to %[[ALLOCA]] : !fir.ref<f64>
+! CHECK: omp.atomic.read %[[VAL_Z_DECLARE]]#1 = %[[ALLOCA]] : !fir.ref<f64>, !fir.ref<f64>, f64
+    !$omp atomic read
+       z = x  ! integer(4) -> real(8)
+
+! CHECK: %[[ALLOCA:.*]] = fir.alloca i32
+! CHECK: %[[LOAD:.*]] = fir.load %[[VAL_W_DECLARE]]#1 : !fir.ref<complex<f32>>
+! CHECK: %[[EXT:.*]] = fir.extract_value %[[LOAD]], [0 : index] : (complex<f32>) -> f32
+! CHECK: %[[CVT:.*]] = fir.convert %[[EXT]] : (f32) -> i32
+! CHECK: fir.store %[[CVT]] to %[[ALLOCA]] : !fir.ref<i32>
+! CHECK: omp.atomic.read %[[VAL_X_DECLARE]]#1 = %[[ALLOCA]] : !fir.ref<i32>, !fir.ref<i32>, i32
+    !$omp atomic read
+       x = w  ! complex(4) -> integer(4): real part only
+
+! CHECK: %[[ALLOCA:.*]] = fir.alloca f32
+! CHECK: %[[LOAD:.*]] = fir.load %[[VAL_W_DECLARE]]#1 : !fir.ref<complex<f32>>
+! CHECK: %[[EXT:.*]] = fir.extract_value %[[LOAD]], [0 : index] : (complex<f32>) -> f32
+! CHECK: %[[CVT:.*]] = fir.convert %[[EXT]] : (f32) -> f32
+! CHECK: fir.store %[[CVT]] to %[[ALLOCA]] : !fir.ref<f32>
+! CHECK: omp.atomic.read %[[VAL_Y_DECLARE]]#1 = %[[ALLOCA]] : !fir.ref<f32>, !fir.ref<f32>, f32
+    !$omp atomic read
+       y = w  ! complex(4) -> real(4): real part only
+
+! CHECK: %[[ALLOCA:.*]] = fir.alloca complex<f64>
+! CHECK: %[[LOAD:.*]] = fir.load %[[VAL_W_DECLARE]]#1 : !fir.ref<complex<f32>>
+! CHECK: %[[EXT0:.*]] = fir.extract_value %[[LOAD]], [0 : index] : (complex<f32>) -> f32
+! CHECK: %[[EXT1:.*]] = fir.extract_value %[[LOAD]], [1 : index] : (complex<f32>) -> f32
+! CHECK: %[[CVT0:.*]] = fir.convert %[[EXT0]] : (f32) -> f64
+! CHECK: %[[CVT1:.*]] = fir.convert %[[EXT1]] : (f32) -> f64
+! CHECK: %[[UNDEF:.*]] = fir.undefined complex<f64>
+! CHECK: %[[INSERT1:.*]] = fir.insert_value %[[UNDEF]], %[[CVT0]], [0 : index] : (complex<f64>, f64) -> complex<f64>
+! CHECK: %[[INSERT2:.*]] = fir.insert_value %[[INSERT1]], %[[CVT1]], [1 : index] : (complex<f64>, f64) -> complex<f64>
+! CHECK: fir.store %[[INSERT2]] to %[[ALLOCA]] : !fir.ref<complex<f64>>
+! CHECK: omp.atomic.read %[[VAL_M_DECLARE]]#1 = %[[ALLOCA]] : !fir.ref<complex<f64>>, !fir.ref<complex<f64>>, complex<f64>
+    !$omp atomic read
+       m = w  ! complex(4) -> complex(8): both parts widened
+end subroutine
+! CHECK: func.func @_QPatomic_implicit_cast_write()
+subroutine atomic_implicit_cast_write  ! atomic write where x and expr differ in type
+! CHECK: %[[VAL_M:.*]] = fir.alloca complex<f64> {bindc_name = "m", uniq_name = "_QFatomic_implicit_cast_writeEm"}
+! CHECK: %[[VAL_M_DECLARE:.*]]:2 = hlfir.declare %[[VAL_M]] {uniq_name = "_QFatomic_implicit_cast_writeEm"} : (!fir.ref<complex<f64>>) -> (!fir.ref<complex<f64>>, !fir.ref<complex<f64>>)
+! CHECK: %[[VAL_W:.*]] = fir.alloca complex<f32> {bindc_name = "w", uniq_name = "_QFatomic_implicit_cast_writeEw"}
+! CHECK: %[[VAL_W_DECLARE:.*]]:2 = hlfir.declare %[[VAL_W]] {uniq_name = "_QFatomic_implicit_cast_writeEw"} : (!fir.ref<complex<f32>>) -> (!fir.ref<complex<f32>>, !fir.ref<complex<f32>>)
+! CHECK: %[[VAL_X:.*]] = fir.alloca i32 {bindc_name = "x", uniq_name = "_QFatomic_implicit_cast_writeEx"}
+! CHECK: %[[VAL_X_DECLARE:.*]]:2 = hlfir.declare %[[VAL_X]] {uniq_name = "_QFatomic_implicit_cast_writeEx"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
+! CHECK: %[[VAL_Y:.*]] = fir.alloca f32 {bindc_name = "y", uniq_name = "_QFatomic_implicit_cast_writeEy"}
+! CHECK: %[[VAL_Y_DECLARE:.*]]:2 = hlfir.declare %[[VAL_Y]] {uniq_name = "_QFatomic_implicit_cast_writeEy"} : (!fir.ref<f32>) -> (!fir.ref<f32>, !fir.ref<f32>)
+! CHECK: %[[VAL_Z:.*]] = fir.alloca f64 {bindc_name = "z", uniq_name = "_QFatomic_implicit_cast_writeEz"}
+! CHECK: %[[VAL_Z_DECLARE:.*]]:2 = hlfir.declare %[[VAL_Z]] {uniq_name = "_QFatomic_implicit_cast_writeEz"} : (!fir.ref<f64>) -> (!fir.ref<f64>, !fir.ref<f64>)
+    integer :: x
+    real    :: y
+    double precision :: z
+    complex :: w
+    complex(8) :: m
+    ! Atomic write: the converted value is written directly to the shared variable
+! CHECK: %[[LOAD:.*]] = fir.load %[[VAL_Y_DECLARE]]#0 : !fir.ref<f32>
+! CHECK: %[[CVT:.*]] = fir.convert %[[LOAD]] : (f32) -> i32
+! CHECK: omp.atomic.write %[[VAL_X_DECLARE]]#1 = %[[CVT]] : !fir.ref<i32>, i32
+    !$omp atomic write
+       x = y  ! real(4) -> integer(4)
+
+! CHECK: %[[LOAD:.*]] = fir.load %[[VAL_X_DECLARE]]#0 : !fir.ref<i32>
+! CHECK: %[[CVT:.*]] = fir.convert %[[LOAD]] : (i32) -> f64
+! CHECK: omp.atomic.write %[[VAL_Z_DECLARE]]#1 = %[[CVT]] : !fir.ref<f64>, f64
+    !$omp atomic write
+       z = x  ! integer(4) -> real(8)
+
+! CHECK: %[[LOAD:.*]] = fir.load %[[VAL_W_DECLARE]]#0 : !fir.ref<complex<f32>>
+! CHECK: %[[EXT:.*]] = fir.extract_value %[[LOAD]], [0 : index] : (complex<f32>) -> f32
+! CHECK: %[[CVT:.*]] = fir.convert %[[EXT]] : (f32) -> i32
+! CHECK: omp.atomic.write %[[VAL_X_DECLARE]]#1 = %[[CVT]] : !fir.ref<i32>, i32
+    !$omp atomic write
+       x = w  ! complex(4) -> integer(4): real part only
+
+! CHECK: %[[LOAD:.*]] = fir.load %[[VAL_W_DECLARE]]#0 : !fir.ref<complex<f32>>
+! CHECK: %[[EXT:.*]] = fir.extract_value %[[LOAD]], [0 : index] : (complex<f32>) -> f32
+! CHECK: omp.atomic.write %[[VAL_Y_DECLARE]]#1 = %[[EXT]] : !fir.ref<f32>, f32
+    !$omp atomic write
+       y = w  ! complex(4) -> real(4): real part only
+
+! CHECK: %[[LOAD:.*]] = fir.load %[[VAL_W_DECLARE]]#0 : !fir.ref<complex<f32>>
+! CHECK: %[[CVT:.*]] = fir.convert %[[LOAD]] : (complex<f32>) -> complex<f64>
+! CHECK: omp.atomic.write %[[VAL_M_DECLARE]]#1 = %[[CVT]] : !fir.ref<complex<f64>>, complex<f64>
+    !$omp atomic write
+       m = w  ! complex(4) -> complex(8)
+end subroutine
diff --git a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
index e34e93442ff85..5e07668c43396 100644
--- a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
+++ b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
@@ -268,33 +268,6 @@ computeOpenMPScheduleType(ScheduleKind ClauseKind, bool HasChunks,
   return Result;
 }
 
-/// Emit an implicit cast to convert \p XRead to type of variable \p V
-static llvm::Value *emitImplicitCast(IRBuilder<> &Builder, llvm::Value *XRead,
-                                     llvm::Value *V) {
-  // TODO: Add this functionality to the `AtomicInfo` interface
-  llvm::Type *XReadType = XRead->getType();
-  llvm::Type *VType = V->getType();
-  if (llvm::AllocaInst *vAlloca = dyn_cast<llvm::AllocaInst>(V))
-    VType = vAlloca->getAllocatedType();
-
-  if (XReadType->isStructTy() && VType->isStructTy())
-    // No need to extract or convert. A direct
-    // `store` will suffice.
-    return XRead;
-
-  if (XReadType->isStructTy())
-    XRead = Builder.CreateExtractValue(XRead, /*Idxs=*/0);
-  if (VType->isIntegerTy() && XReadType->isFloatingPointTy())
-    XRead = Builder.CreateFPToSI(XRead, VType);
-  else if (VType->isFloatingPointTy() && XReadType->isIntegerTy())
-    XRead = Builder.CreateSIToFP(XRead, VType);
-  else if (VType->isIntegerTy() && XReadType->isIntegerTy())
-    XRead = Builder.CreateIntCast(XRead, VType, true);
-  else if (VType->isFloatingPointTy() && XReadType->isFloatingPointTy())
-    XRead = Builder.CreateFPCast(XRead, VType);
-  return XRead;
-}
-
 /// Make \p Source branch to \p Target.
 ///
 /// Handles two situations:
@@ -8655,8 +8628,6 @@ OpenMPIRBuilder::createAtomicRead(const LocationDescription &Loc,
     }
   }
   checkAndEmitFlushAfterAtomic(Loc, AO, AtomicKind::Read);
-  if (XRead->getType() != V.Var->getType())
-    XRead = emitImplicitCast(Builder, XRead, V.Var);
   Builder.CreateStore(XRead, V.Var, V.IsVolatile);
   return Builder.saveIP();
 }
@@ -8941,8 +8912,6 @@ OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createAtomicCapture(
     return AtomicResult.takeError();
   Value *CapturedVal =
       (IsPostfixUpdate ? AtomicResult->first : AtomicResult->second);
-  if (CapturedVal->getType() != V.Var->getType())
-    CapturedVal = emitImplicitCast(Builder, CapturedVal, V.Var);
   Builder.CreateStore(CapturedVal, V.Var, V.IsVolatile);
 
   checkAndEmitFlushAfterAtomic(Loc, AO, AtomicKind::Capture);

``````````

</details>


https://github.com/llvm/llvm-project/pull/131603


More information about the llvm-commits mailing list