[Mlir-commits] [mlir] [mlir] fix MemRefToLLVM lowering of atomic operations (PR #139045)
llvmlistbot at llvm.org
Thu May 8 01:16:20 PDT 2025
llvmbot wrote:
@llvm/pr-subscribers-mlir
Author: Oleksandr "Alex" Zinenko (ftynse)
<details>
<summary>Changes</summary>
We have been confusingly, and arguably incorrectly, lowering `maximumf`/`minimumf` atomic RMW operations in the MemRef dialect to `fmax`/`fmin` atomic RMW operations in the LLVM dialect, which have different NaN-propagation semantics: `maximumf`/`minimumf` propagate NaNs from either operand, whereas `fmax`/`fmin`, which follow the `llvm.maxnum`/`llvm.minnum` intrinsics, return the non-NaN operand. This also contradicts the lowering of the `arith.maximumf`/`arith.minimumf` and `arith.maxnumf`/`arith.minnumf` operations.
Change the lowering to match the terminology in arith.
Add tests for these lowerings.
Keep a debug message in case of surprising behavior downstream (the code may be producing more NaNs now).
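To make the NaN-propagation difference concrete, here is a minimal C++ sketch of the two scalar semantics. `maximum_propagating` models IEEE 754-2019 `maximum` (the behavior documented for LLVM's `llvm.maximum`, which `fmaximum` follows), and `maximum_number` uses `std::fmax` as an approximation of `llvm.maxnum` (which `fmax` follows). Both helper names are hypothetical, and `std::fmax` is only a stand-in: it is not guaranteed to match `llvm.maxnum` for signaling NaNs or signed zeros.

```cpp
#include <cmath>
#include <limits>

// Hypothetical helper: NaN-propagating maximum in the style of IEEE 754-2019
// `maximum` / `llvm.maximum`. Any NaN operand yields NaN, and -0.0 is
// treated as less than +0.0.
float maximum_propagating(float a, float b) {
  if (std::isnan(a) || std::isnan(b))
    return std::numeric_limits<float>::quiet_NaN();
  if (a == 0.0f && b == 0.0f)
    return std::signbit(a) ? b : a;  // prefer +0.0 over -0.0
  return a > b ? a : b;
}

// Hypothetical helper: NaN-ignoring maximum approximating `llvm.maxnum`.
// std::fmax returns the non-NaN operand when exactly one input is NaN.
float maximum_number(float a, float b) { return std::fmax(a, b); }
```

With a NaN in one operand, `maximum_propagating(NaN, 1.0f)` yields NaN while `maximum_number(NaN, 1.0f)` yields `1.0f`. This is why code that previously relied on the `fmax` lowering of `maximumf` may now see more NaNs.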
---
Full diff: https://github.com/llvm/llvm-project/pull/139045.diff
2 Files Affected:
- (modified) mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp (+13)
- (modified) mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir (+9-1)
``````````diff
diff --git a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
index c8b2c0bdc6c20..0eade14ee89e5 100644
--- a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
+++ b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
@@ -28,6 +28,9 @@
#include "llvm/Support/MathExtras.h"
#include <optional>
+#define DEBUG_TYPE "memref-to-llvm"
+#define DBGS() llvm::dbgs() << "[" DEBUG_TYPE "] "
+
namespace mlir {
#define GEN_PASS_DEF_FINALIZEMEMREFTOLLVMCONVERSIONPASS
#include "mlir/Conversion/Passes.h.inc"
@@ -1773,12 +1776,22 @@ matchSimpleAtomicOp(memref::AtomicRMWOp atomicOp) {
case arith::AtomicRMWKind::assign:
return LLVM::AtomicBinOp::xchg;
case arith::AtomicRMWKind::maximumf:
+ // TODO: remove this by end of 2025.
+ LLVM_DEBUG(DBGS() << "the lowering of memref.atomicrmw maximumf changed "
+ "from fmax to fmaximum, expect more NaNs");
+ return LLVM::AtomicBinOp::fmaximum;
+ case arith::AtomicRMWKind::maxnumf:
return LLVM::AtomicBinOp::fmax;
case arith::AtomicRMWKind::maxs:
return LLVM::AtomicBinOp::max;
case arith::AtomicRMWKind::maxu:
return LLVM::AtomicBinOp::umax;
case arith::AtomicRMWKind::minimumf:
+ // TODO: remove this by end of 2025.
+ LLVM_DEBUG(DBGS() << "the lowering of memref.atomicrmw minimumf changed "
+ "from fmin to fminimum, expect more NaNs");
+ return LLVM::AtomicBinOp::fminimum;
+ case arith::AtomicRMWKind::minnumf:
return LLVM::AtomicBinOp::fmin;
case arith::AtomicRMWKind::mins:
return LLVM::AtomicBinOp::min;
diff --git a/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir b/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir
index 5538ddf8e4c3c..a986a39fc1e92 100644
--- a/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir
+++ b/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir
@@ -427,11 +427,19 @@ func.func @atomic_rmw(%I : memref<10xi32>, %ival : i32, %F : memref<10xf32>, %fv
// CHECK: llvm.atomicrmw umin %{{.*}}, %{{.*}} acq_rel
memref.atomic_rmw addf %fval, %F[%i] : (f32, memref<10xf32>) -> f32
// CHECK: llvm.atomicrmw fadd %{{.*}}, %{{.*}} acq_rel
+ memref.atomic_rmw maximumf %fval, %F[%i] : (f32, memref<10xf32>) -> f32
+ // CHECK: llvm.atomicrmw fmaximum %{{.*}}, %{{.*}} acq_rel
+ memref.atomic_rmw maxnumf %fval, %F[%i] : (f32, memref<10xf32>) -> f32
+ // CHECK: llvm.atomicrmw fmax %{{.*}}, %{{.*}} acq_rel
+ memref.atomic_rmw minimumf %fval, %F[%i] : (f32, memref<10xf32>) -> f32
+ // CHECK: llvm.atomicrmw fminimum %{{.*}}, %{{.*}} acq_rel
+ memref.atomic_rmw minnumf %fval, %F[%i] : (f32, memref<10xf32>) -> f32
+ // CHECK: llvm.atomicrmw fmin %{{.*}}, %{{.*}} acq_rel
memref.atomic_rmw ori %ival, %I[%i] : (i32, memref<10xi32>) -> i32
// CHECK: llvm.atomicrmw _or %{{.*}}, %{{.*}} acq_rel
memref.atomic_rmw andi %ival, %I[%i] : (i32, memref<10xi32>) -> i32
// CHECK: llvm.atomicrmw _and %{{.*}}, %{{.*}} acq_rel
- // CHECK-INTERFACE-COUNT-9: llvm.atomicrmw
+ // CHECK-INTERFACE-COUNT-13: llvm.atomicrmw
return
}
``````````
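The new CHECK lines expect, for example, `llvm.atomicrmw fminimum %{{.*}}, %{{.*}} acq_rel`. To illustrate roughly what such an operation does at run time, here is a minimal C++ sketch that emulates it with a compare-exchange loop on `std::atomic<float>`. The helper names are hypothetical and not part of MLIR or LLVM. A real backend may implement the operation natively or with its own loop.

```cpp
#include <atomic>
#include <cmath>
#include <limits>

// Hypothetical helper: NaN-propagating minimum (IEEE 754-2019 `minimum`
// style). Any NaN operand yields NaN, and -0.0 is treated as less than +0.0.
float minimum_propagating(float a, float b) {
  if (std::isnan(a) || std::isnan(b))
    return std::numeric_limits<float>::quiet_NaN();
  if (a == 0.0f && b == 0.0f)
    return std::signbit(a) ? a : b;  // prefer -0.0 over +0.0
  return a < b ? a : b;
}

// Hypothetical emulation of `atomicrmw fminimum ... acq_rel`: atomically
// replace the stored value with minimum(stored, val) and return the old value.
float atomic_fminimum(std::atomic<float>& loc, float val) {
  float old = loc.load(std::memory_order_relaxed);
  // On failure, compare_exchange_weak reloads `old` and the loop retries.
  // The comparison is bitwise, not `==`, so a stored NaN still compares
  // equal to the NaN just loaded and the loop terminates.
  while (!loc.compare_exchange_weak(old, minimum_propagating(old, val),
                                    std::memory_order_acq_rel,
                                    std::memory_order_relaxed)) {
  }
  return old;
}
```

Once a NaN is stored, every later `atomic_fminimum` on that location keeps it NaN. An `fmin`-style (`minnum`) update would instead replace the NaN with the incoming non-NaN value.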
</details>
https://github.com/llvm/llvm-project/pull/139045