[Mlir-commits] [mlir] [mlir] Fix correct memset range in `OwningMemRef` zero-init (PR #158200)

Ryan Kim llvmlistbot at llvm.org
Sat Sep 13 04:17:20 PDT 2025


https://github.com/chokobole updated https://github.com/llvm/llvm-project/pull/158200

>From 7986481d38213fb52506f7ecb223621db34a5660 Mon Sep 17 00:00:00 2001
From: Ryan Kim <chokobole33 at gmail.com>
Date: Sat, 13 Sep 2025 20:16:58 +0900
Subject: [PATCH] [mlir] Fix correct memset range in OwningMemRef zero-init

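`OwningMemRef` previously called `memset()` on `descriptor.data` (the aligned
pointer) with a length of `size + desiredAlignment`. Because the aligned
pointer can lie past the start of the allocation (`data != alignedData`),
this could write past the end of the allocated region.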
---
 .../include/mlir/ExecutionEngine/MemRefUtils.h | 10 ++++------
 mlir/unittests/ExecutionEngine/Invoke.cpp      | 18 ++++++++++++++++++
 2 files changed, 22 insertions(+), 6 deletions(-)

diff --git a/mlir/include/mlir/ExecutionEngine/MemRefUtils.h b/mlir/include/mlir/ExecutionEngine/MemRefUtils.h
index d66d757cb7a8e..e9471731afe13 100644
--- a/mlir/include/mlir/ExecutionEngine/MemRefUtils.h
+++ b/mlir/include/mlir/ExecutionEngine/MemRefUtils.h
@@ -164,19 +164,17 @@ class OwningMemRef {
     int64_t nElements = 1;
     for (int64_t s : shapeAlloc)
       nElements *= s;
-    auto [data, alignedData] =
+    auto [allocatedPtr, alignedData] =
         detail::allocAligned<T>(nElements, allocFun, alignment);
-    descriptor = detail::makeStridedMemRefDescriptor<Rank>(data, alignedData,
-                                                           shape, shapeAlloc);
+    descriptor = detail::makeStridedMemRefDescriptor<Rank>(
+        allocatedPtr, alignedData, shape, shapeAlloc);
     if (init) {
       for (StridedMemrefIterator<T, Rank> it = descriptor.begin(),
                                           end = descriptor.end();
            it != end; ++it)
         init(*it, it.getIndices());
     } else {
-      memset(descriptor.data, 0,
-             nElements * sizeof(T) +
-                 alignment.value_or(detail::nextPowerOf2(sizeof(T))));
+      memset(alignedData, 0, nElements * sizeof(T));
     }
   }
   /// Take ownership of an existing descriptor with a custom deleter.
diff --git a/mlir/unittests/ExecutionEngine/Invoke.cpp b/mlir/unittests/ExecutionEngine/Invoke.cpp
index cdeeca20610f0..3161c7053f7a4 100644
--- a/mlir/unittests/ExecutionEngine/Invoke.cpp
+++ b/mlir/unittests/ExecutionEngine/Invoke.cpp
@@ -251,6 +251,24 @@ TEST(NativeMemRefJit, SKIP_WITHOUT_JIT(BasicMemref)) {
   EXPECT_EQ((a[{2, 1}]), 42.);
 }
 
+TEST(NativeMemRefJit, SKIP_WITHOUT_JIT(OwningMemrefZeroInit)) {
+  constexpr int k = 3;
+  constexpr int m = 7;
+  int64_t shape[] = {k, m};
+  // Use a large alignment to stress the case where the memref data/basePtr are
+  // disjoint.
+  int alignment = 8192;
+  OwningMemRef<float, 2> a(shape, {}, {}, alignment);
+  ASSERT_EQ(
+      (void *)(((uintptr_t)a->basePtr + alignment - 1) & ~(alignment - 1)),
+      a->data);
+  for (int i = 0; i < k; ++i) {
+    for (int j = 0; j < m; ++j) {
+      EXPECT_EQ((a[{i, j}]), 0.);
+    }
+  }
+}
+
 // A helper function that will be called from the JIT
 static void memrefMultiply(::StridedMemRefType<float, 2> *memref,
                            int32_t coefficient) {



More information about the Mlir-commits mailing list