[Mlir-commits] [mlir] [mlir] Fix correct memset range in `OwningMemRef` zero-init (PR #158200)
Ryan Kim
llvmlistbot at llvm.org
Sat Sep 13 04:17:20 PDT 2025
https://github.com/chokobole updated https://github.com/llvm/llvm-project/pull/158200
>From 7986481d38213fb52506f7ecb223621db34a5660 Mon Sep 17 00:00:00 2001
From: Ryan Kim <chokobole33 at gmail.com>
Date: Sat, 13 Sep 2025 20:16:58 +0900
Subject: [PATCH] [mlir] Fix memset range in OwningMemRef zero-init
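`OwningMemRef` previously called `memset()` on `descriptor.data` (the aligned
pointer) with `size + desiredAlignment` bytes. Because the aligned pointer can
lie up to `alignment - 1` bytes past the start of the allocation (the allocated
pointer is not the aligned pointer), this could write past the allocated
region. Zero-init now clears only the `nElements * sizeof(T)` bytes starting at
the aligned pointer.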
---
.../include/mlir/ExecutionEngine/MemRefUtils.h | 10 ++++------
mlir/unittests/ExecutionEngine/Invoke.cpp | 18 ++++++++++++++++++
2 files changed, 22 insertions(+), 6 deletions(-)
diff --git a/mlir/include/mlir/ExecutionEngine/MemRefUtils.h b/mlir/include/mlir/ExecutionEngine/MemRefUtils.h
index d66d757cb7a8e..e9471731afe13 100644
--- a/mlir/include/mlir/ExecutionEngine/MemRefUtils.h
+++ b/mlir/include/mlir/ExecutionEngine/MemRefUtils.h
@@ -164,19 +164,17 @@ class OwningMemRef {
int64_t nElements = 1;
for (int64_t s : shapeAlloc)
nElements *= s;
- auto [data, alignedData] =
+ auto [allocatedPtr, alignedData] =
detail::allocAligned<T>(nElements, allocFun, alignment);
- descriptor = detail::makeStridedMemRefDescriptor<Rank>(data, alignedData,
- shape, shapeAlloc);
+ descriptor = detail::makeStridedMemRefDescriptor<Rank>(
+ allocatedPtr, alignedData, shape, shapeAlloc);
if (init) {
for (StridedMemrefIterator<T, Rank> it = descriptor.begin(),
end = descriptor.end();
it != end; ++it)
init(*it, it.getIndices());
} else {
- memset(descriptor.data, 0,
- nElements * sizeof(T) +
- alignment.value_or(detail::nextPowerOf2(sizeof(T))));
+ memset(alignedData, 0, nElements * sizeof(T));
}
}
/// Take ownership of an existing descriptor with a custom deleter.
diff --git a/mlir/unittests/ExecutionEngine/Invoke.cpp b/mlir/unittests/ExecutionEngine/Invoke.cpp
index cdeeca20610f0..3161c7053f7a4 100644
--- a/mlir/unittests/ExecutionEngine/Invoke.cpp
+++ b/mlir/unittests/ExecutionEngine/Invoke.cpp
@@ -251,6 +251,24 @@ TEST(NativeMemRefJit, SKIP_WITHOUT_JIT(BasicMemref)) {
EXPECT_EQ((a[{2, 1}]), 42.);
}
+TEST(NativeMemRefJit, SKIP_WITHOUT_JIT(OwningMemrefZeroInit)) {
+ constexpr int k = 3;
+ constexpr int m = 7;
+ int64_t shape[] = {k, m}; // 3x7 float memref, zero-initialized below.
+ // Large alignment makes data (basePtr rounded up) differ from basePtr; the
+ // zero-init path must then memset only k*m floats starting at data.
+ int alignment = 8192;
+ OwningMemRef<float, 2> a(shape, {}, {}, alignment);
+ ASSERT_EQ(
+ (void *)(((uintptr_t)a->basePtr + alignment - 1) & ~(alignment - 1)),
+ a->data); // data must be basePtr rounded up to a multiple of alignment.
+ for (int i = 0; i < k; ++i) {
+ for (int j = 0; j < m; ++j) {
+ EXPECT_EQ((a[{i, j}]), 0.); // Every element must read back as zero.
+ }
+ }
+}
+
// A helper function that will be called from the JIT
static void memrefMultiply(::StridedMemRefType<float, 2> *memref,
int32_t coefficient) {
More information about the Mlir-commits
mailing list