[Mlir-commits] [mlir] [MLIR][MemRefToLLVM] Diagnose rank/index mismatch in generic_atomic_rmw lowering (PR #178704)
llvmlistbot at llvm.org
Thu Jan 29 09:28:30 PST 2026
llvmbot wrote:
@llvm/pr-subscribers-mlir
Author: Ayush Kumar Gaur (Ayush3941)
<details>
<summary>Changes</summary>
### What the problem was:
The memref.generic_atomic_rmw lowering crashes when the number of indices does not match the memref rank, because the pointer computation then accesses memory out of bounds.
### Why it happened:
The MemRefToLLVM lowering assumes that the index count equals the memref rank but never validates it. As a result, invalid IR reaches getStridedElementPtr and trips an assertion.
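Here is a rough model of the pointer computation. It assumes MLIR's standard strided layout, and the symbols are illustrative rather than taken from the patch. Consider a rank-$r$ memref with offset $o$ and strides $s_0, \dots, s_{r-1}$, accessed with $n$ indices $i_0, \dots, i_{n-1}$. The element position relative to the aligned base pointer is then:

```math
\text{elementOffset} = o + \sum_{k=0}^{n-1} i_k \, s_k
```

The sum consumes one stride per index, so it is only well defined when $n = r$. In the new test case ($r = 0$, $n = 1$), the term $i_0 s_0$ refers to a stride that does not exist.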
### What the fix is:
Add an explicit check in GenericAtomicRMWOpLowering that the index count matches the memref rank, and emit a diagnostic instead of crashing.
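The following sketch illustrates the guard outside of MLIR. It is a minimal, self-contained C++ example, not MLIR code. `StridedLayout`, `OffsetResult`, and `computeLinearOffset` are hypothetical names that model a strided memref's offset and per-dimension strides. The sketch compares the index count against the rank before doing any stride arithmetic, as the patch does. A mismatch therefore produces a diagnostic string instead of an out-of-bounds read.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <optional>
#include <string>
#include <vector>

// Hypothetical stand-in for a memref's static strided layout: a base offset
// plus one stride per dimension. This is NOT the MLIR MemRefType API.
struct StridedLayout {
  int64_t offset;
  std::vector<int64_t> strides; // strides.size() is the rank
};

// Either a computed linear offset or an error message.
struct OffsetResult {
  std::optional<int64_t> offset; // set on success
  std::string error;             // non-empty on failure
};

// Computes offset + sum(indices[i] * strides[i]). The rank check runs first,
// so a mismatched index count never reads past the end of `strides`.
OffsetResult computeLinearOffset(const StridedLayout &layout,
                                 const std::vector<int64_t> &indices) {
  const std::size_t rank = layout.strides.size();
  if (indices.size() != rank) {
    return {std::nullopt, "index count (" + std::to_string(indices.size()) +
                              ") does not match memref rank (" +
                              std::to_string(rank) + ")"};
  }
  int64_t linear = layout.offset;
  for (std::size_t i = 0; i < rank; ++i)
    linear += indices[i] * layout.strides[i];
  return {linear, ""};
}
```

Two examples show both paths:

- **Mismatch:** a rank-0 layout (`StridedLayout{0, {}}`) accessed with one index returns the same message that the new test expects.
- **Valid access:** a 3x4 row-major layout (`StridedLayout{0, {4, 1}}`) with indices `{2, 3}` yields offset 11.

Returning an error value rather than asserting mirrors the patch's choice to report invalid IR as a diagnostic rather than crash.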
Fixes #178211
---
Full diff: https://github.com/llvm/llvm-project/pull/178704.diff
2 Files Affected:
- (modified) mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp (+5)
- (modified) mlir/test/Conversion/MemRefToLLVM/invalid.mlir (+11)
``````````diff
diff --git a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
index 91a0c4b55fa84..88d212b2ebfb3 100644
--- a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
+++ b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
@@ -722,6 +722,11 @@ struct GenericAtomicRMWOpLowering
     // Compute the loaded value and branch to the loop block.
     rewriter.setInsertionPointToEnd(initBlock);
     auto memRefType = cast<MemRefType>(atomicOp.getMemref().getType());
+    if (adaptor.getIndices().size() != (size_t)memRefType.getRank()) {
+      return atomicOp.emitError()
+             << "index count (" << adaptor.getIndices().size()
+             << ") does not match memref rank (" << memRefType.getRank() << ")";
+    }
     auto dataPtr = getStridedElementPtr(
         rewriter, loc, memRefType, adaptor.getMemref(), adaptor.getIndices());
     Value init = LLVM::LoadOp::create(
diff --git a/mlir/test/Conversion/MemRefToLLVM/invalid.mlir b/mlir/test/Conversion/MemRefToLLVM/invalid.mlir
index 5462d3278d9e6..dc44d3712344b 100644
--- a/mlir/test/Conversion/MemRefToLLVM/invalid.mlir
+++ b/mlir/test/Conversion/MemRefToLLVM/invalid.mlir
@@ -52,3 +52,14 @@ func.func @test_atomic_exch(%arg0: memref<?xi32>, %idx: index, %value: i32) {
   }
   func.return
 }
+
+// -----
+
+func.func @generic_atomic_rmw_rank_mismatch(%arg0: memref<i32>, %idx: index) {
+  // expected-error @+1 {{index count (1) does not match memref rank (0)}}
+  %r = memref.generic_atomic_rmw %arg0[%idx] : memref<i32> {
+  ^bb0(%v: i32):
+    memref.atomic_yield %v : i32
+  }
+  func.return
+}
``````````
</details>
https://github.com/llvm/llvm-project/pull/178704
More information about the Mlir-commits mailing list