[Mlir-commits] [mlir] [mlir] fix MemRefToLLVM lowering of atomic operations (PR #139045)
Oleksandr Alex Zinenko
llvmlistbot at llvm.org
Thu May 8 01:15:40 PDT 2025
https://github.com/ftynse created https://github.com/llvm/llvm-project/pull/139045
We have been confusingly, and arguably incorrectly, lowering the `maximumf`/`minimumf` atomic RMW operations in the MemRef dialect to the `fmax`/`fmin` atomic RMW operations in the LLVM dialect, which have different NaN-propagation semantics: `maximumf`/`minimumf` propagate a NaN from either operand, whereas `fmax`/`fmin`, which match the `llvm.maxnum`/`llvm.minnum` intrinsics, return the non-NaN operand. This also contradicts the lowering of the `arith.maximumf`/`arith.minimumf` and `arith.maxnumf`/`arith.minnumf` operations.
Change the lowering to match the terminology in arith.
Add tests for these lowerings.
Keep a debug message in case of surprising behavior downstream (the code may be producing more NaNs now).
From cc9dfde0466a2ac7cdb23520200d62e5d6532a92 Mon Sep 17 00:00:00 2001
From: Alex Zinenko <git at ozinenko.com>
Date: Thu, 8 May 2025 10:03:37 +0200
Subject: [PATCH] [mlir] fix MemRefToLLVM lowering of atomic operations
---
mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp | 13 +++++++++++++
.../Conversion/MemRefToLLVM/memref-to-llvm.mlir | 10 +++++++++-
2 files changed, 22 insertions(+), 1 deletion(-)
diff --git a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
index c8b2c0bdc6c20..0eade14ee89e5 100644
--- a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
+++ b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
@@ -28,6 +28,9 @@
#include "llvm/Support/MathExtras.h"
#include <optional>
+#define DEBUG_TYPE "memref-to-llvm"
+#define DBGS() llvm::dbgs() << "[" DEBUG_TYPE "] "
+
namespace mlir {
#define GEN_PASS_DEF_FINALIZEMEMREFTOLLVMCONVERSIONPASS
#include "mlir/Conversion/Passes.h.inc"
@@ -1773,12 +1776,22 @@ matchSimpleAtomicOp(memref::AtomicRMWOp atomicOp) {
case arith::AtomicRMWKind::assign:
return LLVM::AtomicBinOp::xchg;
case arith::AtomicRMWKind::maximumf:
+ // TODO: remove this by end of 2025.
+ LLVM_DEBUG(DBGS() << "the lowering of memref.atomic_rmw maximumf changed "
+ "from fmax to fmaximum, expect more NaNs\n");
+ return LLVM::AtomicBinOp::fmaximum;
+ case arith::AtomicRMWKind::maxnumf:
return LLVM::AtomicBinOp::fmax;
case arith::AtomicRMWKind::maxs:
return LLVM::AtomicBinOp::max;
case arith::AtomicRMWKind::maxu:
return LLVM::AtomicBinOp::umax;
case arith::AtomicRMWKind::minimumf:
+ // TODO: remove this by end of 2025.
+ LLVM_DEBUG(DBGS() << "the lowering of memref.atomic_rmw minimumf changed "
+ "from fmin to fminimum, expect more NaNs\n");
+ return LLVM::AtomicBinOp::fminimum;
+ case arith::AtomicRMWKind::minnumf:
return LLVM::AtomicBinOp::fmin;
case arith::AtomicRMWKind::mins:
return LLVM::AtomicBinOp::min;
diff --git a/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir b/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir
index 5538ddf8e4c3c..a986a39fc1e92 100644
--- a/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir
+++ b/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir
@@ -427,11 +427,19 @@ func.func @atomic_rmw(%I : memref<10xi32>, %ival : i32, %F : memref<10xf32>, %fv
// CHECK: llvm.atomicrmw umin %{{.*}}, %{{.*}} acq_rel
memref.atomic_rmw addf %fval, %F[%i] : (f32, memref<10xf32>) -> f32
// CHECK: llvm.atomicrmw fadd %{{.*}}, %{{.*}} acq_rel
+ memref.atomic_rmw maximumf %fval, %F[%i] : (f32, memref<10xf32>) -> f32
+ // CHECK: llvm.atomicrmw fmaximum %{{.*}}, %{{.*}} acq_rel
+ memref.atomic_rmw maxnumf %fval, %F[%i] : (f32, memref<10xf32>) -> f32
+ // CHECK: llvm.atomicrmw fmax %{{.*}}, %{{.*}} acq_rel
+ memref.atomic_rmw minimumf %fval, %F[%i] : (f32, memref<10xf32>) -> f32
+ // CHECK: llvm.atomicrmw fminimum %{{.*}}, %{{.*}} acq_rel
+ memref.atomic_rmw minnumf %fval, %F[%i] : (f32, memref<10xf32>) -> f32
+ // CHECK: llvm.atomicrmw fmin %{{.*}}, %{{.*}} acq_rel
memref.atomic_rmw ori %ival, %I[%i] : (i32, memref<10xi32>) -> i32
// CHECK: llvm.atomicrmw _or %{{.*}}, %{{.*}} acq_rel
memref.atomic_rmw andi %ival, %I[%i] : (i32, memref<10xi32>) -> i32
// CHECK: llvm.atomicrmw _and %{{.*}}, %{{.*}} acq_rel
- // CHECK-INTERFACE-COUNT-9: llvm.atomicrmw
+ // CHECK-INTERFACE-COUNT-13: llvm.atomicrmw
return
}