[Mlir-commits] [mlir] [mlir] fix MemRefToLLVM lowering of atomic operations (PR #139045)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu May 8 01:16:20 PDT 2025


llvmbot wrote:



@llvm/pr-subscribers-mlir

Author: Oleksandr "Alex" Zinenko (ftynse)

<details>
<summary>Changes</summary>

We have been confusingly, and arguably incorrectly, lowering `maximumf`/`minimumf` atomic RMW operations in the MemRef dialect to `fmax`/`fmin` atomic RMW operations in the LLVM dialect. The two have different NaN-propagation semantics: `maximumf`/`minimumf` propagate a NaN from either operand, whereas `fmax`/`fmin`, which follow the semantics of the `llvm.maxnum`/`llvm.minnum` intrinsics, return the non-NaN operand. The old lowering also contradicts the lowering of the `arith.maximumf`/`arith.minimumf` and `arith.maxnumf`/`arith.minnumf` operations.
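
The NaN-propagation difference described above can be illustrated in plain C++. This is only a sketch of the two semantics, not code from this PR: `propagatingMax`, `ignoringMax`, and `checkNanSemantics` are hypothetical names introduced for illustration, and `std::fmax` stands in for the NaN-ignoring behavior.

```cpp
#include <cassert>
#include <cmath>
#include <limits>

// Hypothetical helper (not part of MLIR or LLVM): a NaN-propagating maximum,
// i.e. the result is NaN as soon as either operand is NaN.
inline float propagatingMax(float a, float b) {
  if (std::isnan(a) || std::isnan(b))
    return std::numeric_limits<float>::quiet_NaN();
  if (a == b) // Also covers +0.0 vs -0.0: prefer the operand without a sign bit.
    return std::signbit(a) ? b : a;
  return a > b ? a : b;
}

// NaN-ignoring maximum: std::fmax returns the non-NaN operand when exactly
// one operand is NaN.
inline float ignoringMax(float a, float b) { return std::fmax(a, b); }

// Small self-check contrasting the two behaviors.
inline void checkNanSemantics() {
  const float nan = std::numeric_limits<float>::quiet_NaN();
  assert(std::isnan(propagatingMax(nan, 1.0f))); // NaN wins
  assert(ignoringMax(nan, 1.0f) == 1.0f);        // NaN is dropped
  assert(propagatingMax(2.0f, 3.0f) == 3.0f);    // Same result without NaNs
  assert(ignoringMax(2.0f, 3.0f) == 3.0f);
}
```

The two functions agree whenever neither operand is NaN, so the changed lowering should only be observable for inputs that contain NaNs.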

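Change the lowering to match the terminology in arith: `maximumf`/`minimumf` now lower to `fmaximum`/`fminimum`, and `maxnumf`/`minnumf` lower to `fmax`/`fmin`.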

Add tests for these lowerings.

Emit a temporary debug message (debug type `memref-to-llvm`, to be removed by the end of 2025) to help diagnose surprising behavior downstream, since the lowered code may now produce more NaNs.

---
Full diff: https://github.com/llvm/llvm-project/pull/139045.diff


2 Files Affected:

- (modified) mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp (+13) 
- (modified) mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir (+9-1) 


``````````diff
diff --git a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
index c8b2c0bdc6c20..0eade14ee89e5 100644
--- a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
+++ b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
@@ -28,6 +28,9 @@
 #include "llvm/Support/MathExtras.h"
 #include <optional>
 
+#define DEBUG_TYPE "memref-to-llvm"
+#define DBGS() llvm::dbgs() << "[" DEBUG_TYPE "] "
+
 namespace mlir {
 #define GEN_PASS_DEF_FINALIZEMEMREFTOLLVMCONVERSIONPASS
 #include "mlir/Conversion/Passes.h.inc"
@@ -1773,12 +1776,22 @@ matchSimpleAtomicOp(memref::AtomicRMWOp atomicOp) {
   case arith::AtomicRMWKind::assign:
     return LLVM::AtomicBinOp::xchg;
   case arith::AtomicRMWKind::maximumf:
+    // TODO: remove this by end of 2025.
+    LLVM_DEBUG(DBGS() << "the lowering of memref.atomicrmw maximumf changed "
+                         "from fmax to fmaximum, expect more NaNs");
+    return LLVM::AtomicBinOp::fmaximum;
+  case arith::AtomicRMWKind::maxnumf:
     return LLVM::AtomicBinOp::fmax;
   case arith::AtomicRMWKind::maxs:
     return LLVM::AtomicBinOp::max;
   case arith::AtomicRMWKind::maxu:
     return LLVM::AtomicBinOp::umax;
   case arith::AtomicRMWKind::minimumf:
+    // TODO: remove this by end of 2025.
+    LLVM_DEBUG(DBGS() << "the lowering of memref.atomicrmw minimum changed "
+                         "from fmin to fminimum, expect more NaNs");
+    return LLVM::AtomicBinOp::fminimum;
+  case arith::AtomicRMWKind::minnumf:
     return LLVM::AtomicBinOp::fmin;
   case arith::AtomicRMWKind::mins:
     return LLVM::AtomicBinOp::min;
diff --git a/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir b/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir
index 5538ddf8e4c3c..a986a39fc1e92 100644
--- a/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir
+++ b/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir
@@ -427,11 +427,19 @@ func.func @atomic_rmw(%I : memref<10xi32>, %ival : i32, %F : memref<10xf32>, %fv
   // CHECK: llvm.atomicrmw umin %{{.*}}, %{{.*}} acq_rel
   memref.atomic_rmw addf %fval, %F[%i] : (f32, memref<10xf32>) -> f32
   // CHECK: llvm.atomicrmw fadd %{{.*}}, %{{.*}} acq_rel
+  memref.atomic_rmw maximumf %fval, %F[%i] : (f32, memref<10xf32>) -> f32
+  // CHECK: llvm.atomicrmw fmaximum %{{.*}}, %{{.*}} acq_rel
+  memref.atomic_rmw maxnumf %fval, %F[%i] : (f32, memref<10xf32>) -> f32
+  // CHECK: llvm.atomicrmw fmax %{{.*}}, %{{.*}} acq_rel
+  memref.atomic_rmw minimumf %fval, %F[%i] : (f32, memref<10xf32>) -> f32
+  // CHECK: llvm.atomicrmw fminimum %{{.*}}, %{{.*}} acq_rel
+  memref.atomic_rmw minnumf %fval, %F[%i] : (f32, memref<10xf32>) -> f32
+  // CHECK: llvm.atomicrmw fmin %{{.*}}, %{{.*}} acq_rel
   memref.atomic_rmw ori %ival, %I[%i] : (i32, memref<10xi32>) -> i32
   // CHECK: llvm.atomicrmw _or %{{.*}}, %{{.*}} acq_rel
   memref.atomic_rmw andi %ival, %I[%i] : (i32, memref<10xi32>) -> i32
   // CHECK: llvm.atomicrmw _and %{{.*}}, %{{.*}} acq_rel
-  // CHECK-INTERFACE-COUNT-9: llvm.atomicrmw
+  // CHECK-INTERFACE-COUNT-13: llvm.atomicrmw
   return
 }
 

``````````

</details>


https://github.com/llvm/llvm-project/pull/139045

