[Mlir-commits] [mlir] [MLIR][MemRefToLLVM] Diagnose rank/index mismatch in generic_atomic_rmw lowering (PR #178704)
Ayush Kumar Gaur
llvmlistbot at llvm.org
Thu Jan 29 09:27:56 PST 2026
https://github.com/Ayush3941 created https://github.com/llvm/llvm-project/pull/178704
### What the problem is:
memref.generic_atomic_rmw lowering crashes when the number of indices does not match the memref rank, triggering an out-of-bounds access in pointer computation.
### Why it happened:
The MemRefToLLVM lowering assumes the index count equals the memref rank and does not validate this, allowing invalid IR to reach getStridedElementPtr and trigger an assertion.
### What the fix is:
Add an explicit rank/index count check in GenericAtomicRMWOpLowering and emit a diagnostic instead of crashing.
Fixes #178211
>From 2a9f124ec9f13eca7b5e15b730d0ea405f1f57f1 Mon Sep 17 00:00:00 2001
From: Ayush3941 <ayushkgaur1 at gmail.com>
Date: Thu, 29 Jan 2026 12:19:03 -0500
Subject: [PATCH 1/2] [MLIR][MemRefToLLVM] Diagnose rank/index mismatch in
generic_atomic_rmw lowering
---
mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp | 5 +++++
mlir/test/Conversion/MemRefToLLVM/invalid.mlir | 11 +++++++++++
2 files changed, 16 insertions(+)
diff --git a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
index 91a0c4b55fa84..9666d45e586de 100644
--- a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
+++ b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
@@ -722,6 +722,11 @@ struct GenericAtomicRMWOpLowering
// Compute the loaded value and branch to the loop block.
rewriter.setInsertionPointToEnd(initBlock);
auto memRefType = cast<MemRefType>(atomicOp.getMemref().getType());
+ if (adaptor.getIndices().size() != (size_t)memRefType.getRank()) {
+ return atomicOp.emitError()
+ << "index count (" << adaptor.getIndices().size()
+ << ") does not match memref rank (" << memRefType.getRank() << ")";
+ }
auto dataPtr = getStridedElementPtr(
rewriter, loc, memRefType, adaptor.getMemref(), adaptor.getIndices());
Value init = LLVM::LoadOp::create(
diff --git a/mlir/test/Conversion/MemRefToLLVM/invalid.mlir b/mlir/test/Conversion/MemRefToLLVM/invalid.mlir
index 5462d3278d9e6..dc44d3712344b 100644
--- a/mlir/test/Conversion/MemRefToLLVM/invalid.mlir
+++ b/mlir/test/Conversion/MemRefToLLVM/invalid.mlir
@@ -52,3 +52,14 @@ func.func @test_atomic_exch(%arg0: memref<?xi32>, %idx: index, %value: i32) {
}
func.return
}
+
+// -----
+
+func.func @generic_atomic_rmw_rank_mismatch(%arg0: memref<i32>, %idx: index) {
+ // expected-error at +1 {{index count (1) does not match memref rank (0)}}
+ %r = memref.generic_atomic_rmw %arg0[%idx] : memref<i32> {
+ ^bb0(%v: i32):
+ memref.atomic_yield %v : i32
+ }
+ func.return
+}
>From 328f997d93c65b7ca1825ce0a1e4dc31fe94d278 Mon Sep 17 00:00:00 2001
From: Ayush3941 <ayushkgaur1 at gmail.com>
Date: Thu, 29 Jan 2026 12:21:14 -0500
Subject: [PATCH 2/2] [MLIR][MemRefToLLVM] Diagnose rank/index mismatch in
generic_atomic_rmw lowering (clang-format fixes)
---
mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp | 6 +++---
1 file changed, 3 insertions(+), 3 deletions(-)
diff --git a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
index 9666d45e586de..88d212b2ebfb3 100644
--- a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
+++ b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
@@ -724,9 +724,9 @@ struct GenericAtomicRMWOpLowering
auto memRefType = cast<MemRefType>(atomicOp.getMemref().getType());
if (adaptor.getIndices().size() != (size_t)memRefType.getRank()) {
return atomicOp.emitError()
- << "index count (" << adaptor.getIndices().size()
- << ") does not match memref rank (" << memRefType.getRank() << ")";
- }
+ << "index count (" << adaptor.getIndices().size()
+ << ") does not match memref rank (" << memRefType.getRank() << ")";
+ }
auto dataPtr = getStridedElementPtr(
rewriter, loc, memRefType, adaptor.getMemref(), adaptor.getIndices());
Value init = LLVM::LoadOp::create(
More information about the Mlir-commits
mailing list