[llvm] [mlir] [llvm] Add implicit cast to omp.atomic.read (PR #114659)

via llvm-commits llvm-commits at lists.llvm.org
Sat Nov 2 03:14:24 PDT 2024


https://github.com/NimishMishra created https://github.com/llvm/llvm-project/pull/114659

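If the types of the two operands of `omp.atomic.read` differ, emit an implicit cast. If the source value is a `struct` (e.g. a complex number) and the destination is not, extract the 0-th element, emit an implicit cast if required, and store the result at the destination; struct-to-struct reads are stored directly.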

Fixes https://github.com/llvm/llvm-project/issues/112908

>From 08c34da4e77c3b1a01840c95846a034a7a432677 Mon Sep 17 00:00:00 2001
From: Nimish Mishra <neelam.nimish at gmail.com>
Date: Sat, 2 Nov 2024 15:40:54 +0530
Subject: [PATCH] [llvm] Add implicit cast to omp.atomic.read

---
 llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp | 30 ++++++++++
 mlir/test/Target/LLVMIR/openmp-llvm.mlir  | 71 +++++++++++++++++++++++
 2 files changed, 101 insertions(+)

diff --git a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
index d2e4dc1c85dfd2..9e69b730884ff6 100644
--- a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
+++ b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
@@ -265,6 +265,32 @@ computeOpenMPScheduleType(ScheduleKind ClauseKind, bool HasChunks,
   return Result;
 }
 
+/// Emit an implicit cast converting \p XRead to the type allocated by \p V.
+/// Non-alloca destinations have no known type, so \p XRead is kept as-is.
+static llvm::Value *emitImplicitCast(IRBuilder<> &Builder, llvm::Value *XRead,
+                                     llvm::Value *V) {
+  auto *VAlloca = dyn_cast<llvm::AllocaInst>(V);
+  if (!VAlloca)
+    return XRead;
+  llvm::Type *VType = VAlloca->getAllocatedType();
+  // A struct (e.g. complex) destination takes the value as-is; a scalar
+  // destination receives the 0-th element of a struct source.
+  if (XRead->getType()->isStructTy() && !VType->isStructTy())
+    XRead = Builder.CreateExtractValue(XRead, /*Idxs=*/0);
+  llvm::Type *XReadType = XRead->getType(); // Type after any extraction.
+  if (XReadType == VType || XReadType->isStructTy() || VType->isStructTy())
+    return XRead;
+  if (VType->isIntegerTy() && XReadType->isFloatingPointTy())
+    return Builder.CreateFPToSI(XRead, VType);
+  if (VType->isFloatingPointTy() && XReadType->isIntegerTy())
+    return Builder.CreateSIToFP(XRead, VType);
+  if (VType->isIntegerTy() && XReadType->isIntegerTy())
+    return Builder.CreateIntCast(XRead, VType, /*isSigned=*/true);
+  if (VType->isFloatingPointTy() && XReadType->isFloatingPointTy())
+    return Builder.CreateFPCast(XRead, VType);
+  return XRead;
+}
+
 /// Make \p Source branch to \p Target.
 ///
 /// Handles two situations:
@@ -8076,6 +8102,8 @@ OpenMPIRBuilder::createAtomicRead(const LocationDescription &Loc,
     }
   }
   checkAndEmitFlushAfterAtomic(Loc, AO, AtomicKind::Read);
+  // Casts only if V.Var is an alloca of a type other than XRead's.
+  XRead = emitImplicitCast(Builder, XRead, V.Var);
   Builder.CreateStore(XRead, V.Var, V.IsVolatile);
   return Builder.saveIP();
 }
@@ -8360,6 +8388,8 @@ OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createAtomicCapture(
     return AtomicResult.takeError();
   Value *CapturedVal =
       (IsPostfixUpdate ? AtomicResult->first : AtomicResult->second);
+  // Casts only if V.Var is an alloca of a type other than CapturedVal's.
+  CapturedVal = emitImplicitCast(Builder, CapturedVal, V.Var);
   Builder.CreateStore(CapturedVal, V.Var, V.IsVolatile);
 
   checkAndEmitFlushAfterAtomic(Loc, AO, AtomicKind::Capture);
diff --git a/mlir/test/Target/LLVMIR/openmp-llvm.mlir b/mlir/test/Target/LLVMIR/openmp-llvm.mlir
index 49f9f3562c78b5..68dd9a66652b1c 100644
--- a/mlir/test/Target/LLVMIR/openmp-llvm.mlir
+++ b/mlir/test/Target/LLVMIR/openmp-llvm.mlir
@@ -1368,6 +1368,77 @@ llvm.func @omp_atomic_read(%arg0 : !llvm.ptr, %arg1 : !llvm.ptr) -> () {
 
 // -----
 
+// CHECK-LABEL: @omp_atomic_read_implicit_cast
+llvm.func @omp_atomic_read_implicit_cast () {
+// CHECK: %[[Z:.*]] = alloca float, i64 1, align 4
+// CHECK: %[[Y:.*]] = alloca double, i64 1, align 8
+// CHECK: %[[X:.*]] = alloca [2 x { float, float }], i64 1, align 8
+// CHECK: %[[W:.*]] = alloca i32, i64 1, align 4
+// CHECK: %[[X_ELEMENT:.*]] = getelementptr { float, float }, ptr %[[X]], i64 0
+  %0 = llvm.mlir.constant(1 : i64) : i64
+  %1 = llvm.alloca %0 x f32 {bindc_name = "z"} : (i64) -> !llvm.ptr
+  %2 = llvm.mlir.constant(1 : i64) : i64
+  %3 = llvm.alloca %2 x f64 {bindc_name = "y"} : (i64) -> !llvm.ptr
+  %4 = llvm.mlir.constant(1 : i64) : i64
+  %5 = llvm.alloca %4 x !llvm.array<2 x struct<(f32, f32)>> {bindc_name = "x"} : (i64) -> !llvm.ptr
+  %6 = llvm.mlir.constant(1 : i64) : i64
+  %7 = llvm.alloca %6 x i32 {bindc_name = "w"} : (i64) -> !llvm.ptr
+  %8 = llvm.mlir.constant(1 : index) : i64
+  %10 = llvm.mlir.constant(1 : i64) : i64
+  %11 = llvm.mlir.constant(0 : i64) : i64
+  %12 = llvm.sub %8, %10 overflow<nsw> : i64
+  %13 = llvm.mul %12, %10 overflow<nsw> : i64
+  %14 = llvm.mul %13, %10 overflow<nsw> : i64
+  %15 = llvm.add %14, %11 overflow<nsw> : i64
+  %17 = llvm.getelementptr %5[%15] : (!llvm.ptr, i64) -> !llvm.ptr, !llvm.struct<(f32, f32)>
+
+// Struct source into an f64 slot: extract element 0, then widen with fpext.
+// CHECK: %[[ATOMIC_LOAD_TEMP:.*]] = alloca { float, float }, align 8
+// CHECK: call void @__atomic_load(i64 8, ptr %[[X_ELEMENT]], ptr %[[ATOMIC_LOAD_TEMP]], i32 0)
+// CHECK: %[[LOAD:.*]] = load { float, float }, ptr %[[ATOMIC_LOAD_TEMP]], align 8
+// CHECK: %[[EXT:.*]] = extractvalue { float, float } %[[LOAD]], 0
+// CHECK: %[[EXT_CAST:.*]] = fpext float %[[EXT]] to double
+// CHECK: store double %[[EXT_CAST]], ptr %[[Y]], align 8
+  omp.atomic.read %3 = %17 : !llvm.ptr, !llvm.struct<(f32, f32)>
+
+// CHECK: %[[ATOMIC_LOAD_TEMP:.*]] = load atomic i32, ptr %[[Z]] monotonic, align 4
+// CHECK: %[[CAST:.*]] = bitcast i32 %[[ATOMIC_LOAD_TEMP]] to float
+// CHECK: %[[LOAD:.*]] = fpext float %[[CAST]] to double
+// CHECK: store double %[[LOAD]], ptr %[[Y]], align 8
+  omp.atomic.read %3 = %1 : !llvm.ptr, f32
+
+// CHECK: %[[ATOMIC_LOAD_TEMP:.*]] = load atomic i32, ptr %[[W]] monotonic, align 4
+// CHECK: %[[LOAD:.*]] = sitofp i32 %[[ATOMIC_LOAD_TEMP]] to double
+// CHECK: store double %[[LOAD]], ptr %[[Y]], align 8
+  omp.atomic.read %3 = %7 : !llvm.ptr, i32
+
+// CHECK: %[[ATOMIC_LOAD_TEMP:.*]] = load atomic i64, ptr %[[Y]] monotonic, align 4
+// CHECK: %[[CAST:.*]] = bitcast i64 %[[ATOMIC_LOAD_TEMP]] to double
+// CHECK: %[[LOAD:.*]] = fptrunc double %[[CAST]] to float
+// CHECK: store float %[[LOAD]], ptr %[[Z]], align 4
+  omp.atomic.read %1 = %3 : !llvm.ptr, f64
+
+// CHECK: %[[ATOMIC_LOAD_TEMP:.*]] = load atomic i32, ptr %[[W]] monotonic, align 4
+// CHECK: %[[LOAD:.*]] = sitofp i32 %[[ATOMIC_LOAD_TEMP]] to float
+// CHECK: store float %[[LOAD]], ptr %[[Z]], align 4
+  omp.atomic.read %1 = %7 : !llvm.ptr, i32
+
+// CHECK: %[[ATOMIC_LOAD_TEMP:.*]] = load atomic i64, ptr %[[Y]] monotonic, align 4
+// CHECK: %[[CAST:.*]] = bitcast i64 %[[ATOMIC_LOAD_TEMP]] to double
+// CHECK: %[[LOAD:.*]] = fptosi double %[[CAST]] to i32
+// CHECK: store i32 %[[LOAD]], ptr %[[W]], align 4
+  omp.atomic.read %7 = %3 : !llvm.ptr, f64
+
+// CHECK: %[[ATOMIC_LOAD_TEMP:.*]] = load atomic i32, ptr %[[Z]] monotonic, align 4
+// CHECK: %[[CAST:.*]] = bitcast i32 %[[ATOMIC_LOAD_TEMP]] to float
+// CHECK: %[[LOAD:.*]] = fptosi float %[[CAST]] to i32
+// CHECK: store i32 %[[LOAD]], ptr %[[W]], align 4
+  omp.atomic.read %7 = %1 : !llvm.ptr, f32
+  llvm.return
+}
+
+// -----
+
 // CHECK-LABEL: @omp_atomic_write
 // CHECK-SAME: (ptr %[[x:.*]], i32 %[[expr:.*]])
 llvm.func @omp_atomic_write(%x: !llvm.ptr, %expr: i32) -> () {



More information about the llvm-commits mailing list