[Mlir-commits] [mlir] [MLIR][MemRefToLLVM] Diagnose rank/index mismatch in generic_atomic_rmw lowering (PR #178704)

Ayush Kumar Gaur llvmlistbot at llvm.org
Thu Jan 29 09:27:56 PST 2026


https://github.com/Ayush3941 created https://github.com/llvm/llvm-project/pull/178704

### What's the problem:
Lowering memref.generic_atomic_rmw to LLVM crashes when the number of indices does not match the memref rank, because the element pointer computation then performs an out-of-bounds access.

### Why it happened:
The MemRefToLLVM lowering assumes that the index count equals the memref rank but never validates it, so invalid IR reaches getStridedElementPtr and trips an assertion.
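The failure mode is easier to see with a simplified model of strided addressing. The plain C++ below is a hypothetical illustration, not MLIR's getStridedElementPtr: the function name `linearizeStrided` and the vector-based layout are assumptions. It assumes the pointer computation reads one per-dimension stride for each supplied index. Under that assumption, an index list longer than the rank, such as a rank-0 memref<i32> accessed with one index, would read past the end of the stride array unless a guard stops it first.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <optional>
#include <vector>

// Hypothetical, simplified model of strided addressing (not MLIR's API).
// The linear element offset is offset + sum_i indices[i] * strides[i],
// with one stride per dimension, so strides.size() plays the role of the rank.
std::optional<int64_t> linearizeStrided(int64_t offset,
                                        const std::vector<int64_t> &strides,
                                        const std::vector<int64_t> &indices) {
  // Without this guard, indices.size() > strides.size() would make
  // strides[i] below read out of bounds.
  if (indices.size() != strides.size())
    return std::nullopt;
  int64_t linear = offset;
  for (std::size_t i = 0; i < indices.size(); ++i)
    linear += indices[i] * strides[i];
  return linear;
}
```

For a memref<4x8xi32> with the identity layout (strides {8, 1}, offset 0), the indices {2, 3} map to element 2 * 8 + 3 = 19. For the rank-0 case, a single index is rejected instead of triggering an out-of-bounds read.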

### What's the fix:
Add an explicit check in GenericAtomicRMWOpLowering that the index count matches the memref rank, and emit a diagnostic instead of crashing.
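The guard reduces to an integer comparison and a formatted message. The sketch below is a hypothetical stand-alone helper (`checkIndexCount` is not an MLIR API) that produces the same text as the expected-error in the new invalid.mlir test. In the actual pattern, the message is streamed into `atomicOp.emitError()` instead of being returned as a string.

```cpp
#include <cstddef>
#include <cstdint>
#include <optional>
#include <string>

// Hypothetical stand-alone version of the index-count guard.
// Returns the diagnostic text on mismatch, or std::nullopt when the
// number of indices equals the memref rank.
std::optional<std::string> checkIndexCount(std::size_t numIndices,
                                           int64_t rank) {
  // Compare as int64_t, the type a memref rank is reported in.
  if (static_cast<int64_t>(numIndices) == rank)
    return std::nullopt;
  return "index count (" + std::to_string(numIndices) +
         ") does not match memref rank (" + std::to_string(rank) + ")";
}
```

Comparing in int64_t rather than casting the rank to size_t, as the patch does, is only a stylistic choice in this sketch. Both forms agree for any ranked memref, whose rank is never negative.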

Fixes #178211

From 2a9f124ec9f13eca7b5e15b730d0ea405f1f57f1 Mon Sep 17 00:00:00 2001
From: Ayush3941 <ayushkgaur1 at gmail.com>
Date: Thu, 29 Jan 2026 12:19:03 -0500
Subject: [PATCH 1/2] [MLIR][MemRefToLLVM] Diagnose rank/index mismatch in
 generic_atomic_rmw lowering

---
 mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp |  5 +++++
 mlir/test/Conversion/MemRefToLLVM/invalid.mlir    | 11 +++++++++++
 2 files changed, 16 insertions(+)

diff --git a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
index 91a0c4b55fa84..9666d45e586de 100644
--- a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
+++ b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
@@ -722,6 +722,11 @@ struct GenericAtomicRMWOpLowering
     // Compute the loaded value and branch to the loop block.
     rewriter.setInsertionPointToEnd(initBlock);
     auto memRefType = cast<MemRefType>(atomicOp.getMemref().getType());
+    if (adaptor.getIndices().size() != (size_t)memRefType.getRank()) {
+      return atomicOp.emitError()
+               << "index count (" << adaptor.getIndices().size()
+               << ") does not match memref rank (" << memRefType.getRank() << ")";
+      }
     auto dataPtr = getStridedElementPtr(
         rewriter, loc, memRefType, adaptor.getMemref(), adaptor.getIndices());
     Value init = LLVM::LoadOp::create(
diff --git a/mlir/test/Conversion/MemRefToLLVM/invalid.mlir b/mlir/test/Conversion/MemRefToLLVM/invalid.mlir
index 5462d3278d9e6..dc44d3712344b 100644
--- a/mlir/test/Conversion/MemRefToLLVM/invalid.mlir
+++ b/mlir/test/Conversion/MemRefToLLVM/invalid.mlir
@@ -52,3 +52,14 @@ func.func @test_atomic_exch(%arg0: memref<?xi32>, %idx: index, %value: i32) {
   }
   func.return
 }
+
+// -----
+
+func.func @generic_atomic_rmw_rank_mismatch(%arg0: memref<i32>, %idx: index) {
+  // expected-error at +1 {{index count (1) does not match memref rank (0)}}
+  %r = memref.generic_atomic_rmw %arg0[%idx] : memref<i32> {
+  ^bb0(%v: i32):
+    memref.atomic_yield %v : i32
+  }
+  func.return
+}

From 328f997d93c65b7ca1825ce0a1e4dc31fe94d278 Mon Sep 17 00:00:00 2001
From: Ayush3941 <ayushkgaur1 at gmail.com>
Date: Thu, 29 Jan 2026 12:21:14 -0500
Subject: [PATCH 2/2] [MLIR][MemRefToLLVM] Diagnose rank/index mismatch in
 generic_atomic_rmw lowering with general format

---
 mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
index 9666d45e586de..88d212b2ebfb3 100644
--- a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
+++ b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
@@ -724,9 +724,9 @@ struct GenericAtomicRMWOpLowering
     auto memRefType = cast<MemRefType>(atomicOp.getMemref().getType());
     if (adaptor.getIndices().size() != (size_t)memRefType.getRank()) {
       return atomicOp.emitError()
-               << "index count (" << adaptor.getIndices().size()
-               << ") does not match memref rank (" << memRefType.getRank() << ")";
-      }
+             << "index count (" << adaptor.getIndices().size()
+             << ") does not match memref rank (" << memRefType.getRank() << ")";
+    }
     auto dataPtr = getStridedElementPtr(
         rewriter, loc, memRefType, adaptor.getMemref(), adaptor.getIndices());
     Value init = LLVM::LoadOp::create(


