[Mlir-commits] [mlir] [mlir] Fix correct memset range in `OwningMemRef` zero-init (PR #158200)
Ryan Kim
llvmlistbot at llvm.org
Fri Sep 12 18:47:45 PDT 2025
https://github.com/chokobole updated https://github.com/llvm/llvm-project/pull/158200
>From dd8f24998a1b0b9730d9d5735d7fafbdcdf5dd08 Mon Sep 17 00:00:00 2001
From: Ryan Kim <chokobole33 at gmail.com>
Date: Fri, 12 Sep 2025 14:12:05 +0900
Subject: [PATCH] [mlir] Use correct memset range in OwningMemRef zero-init
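`OwningMemRef` previously zero-initialized its buffer by calling `memset()`
on `descriptor.data` (the aligned pointer) with a length of
`nElements * sizeof(T)` plus the alignment. Because the aligned pointer can
lie past the start of the allocation (`descriptor.basePtr`), this could write
past the end of the allocated region. Zero exactly `nElements * sizeof(T)`
bytes starting at the aligned pointer instead, and rename the local
`data` to `allocatedPtr` so it is not confused with `descriptor.data`.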
---
.../mlir/ExecutionEngine/MemRefUtils.h | 10 +-
mlir/unittests/ExecutionEngine/Invoke.cpp | 100 ++++++++++--------
2 files changed, 62 insertions(+), 48 deletions(-)
diff --git a/mlir/include/mlir/ExecutionEngine/MemRefUtils.h b/mlir/include/mlir/ExecutionEngine/MemRefUtils.h
index d66d757cb7a8e..e9471731afe13 100644
--- a/mlir/include/mlir/ExecutionEngine/MemRefUtils.h
+++ b/mlir/include/mlir/ExecutionEngine/MemRefUtils.h
@@ -164,19 +164,17 @@ class OwningMemRef {
int64_t nElements = 1;
for (int64_t s : shapeAlloc)
nElements *= s;
- auto [data, alignedData] =
+ auto [allocatedPtr, alignedData] =
detail::allocAligned<T>(nElements, allocFun, alignment);
- descriptor = detail::makeStridedMemRefDescriptor<Rank>(data, alignedData,
- shape, shapeAlloc);
+ descriptor = detail::makeStridedMemRefDescriptor<Rank>(
+ allocatedPtr, alignedData, shape, shapeAlloc);
if (init) {
for (StridedMemrefIterator<T, Rank> it = descriptor.begin(),
end = descriptor.end();
it != end; ++it)
init(*it, it.getIndices());
} else {
- memset(descriptor.data, 0,
- nElements * sizeof(T) +
- alignment.value_or(detail::nextPowerOf2(sizeof(T))));
+ memset(alignedData, 0, nElements * sizeof(T));
}
}
/// Take ownership of an existing descriptor with a custom deleter.
diff --git a/mlir/unittests/ExecutionEngine/Invoke.cpp b/mlir/unittests/ExecutionEngine/Invoke.cpp
index cdeeca20610f0..9f1787fa8a497 100644
--- a/mlir/unittests/ExecutionEngine/Invoke.cpp
+++ b/mlir/unittests/ExecutionEngine/Invoke.cpp
@@ -205,50 +205,66 @@ TEST(NativeMemRefJit, SKIP_WITHOUT_JIT(BasicMemref)) {
};
int64_t shape[] = {k, m};
int64_t shapeAlloc[] = {k + 1, m + 1};
- // Use a large alignment to stress the case where the memref data/basePtr are
- // disjoint.
- int alignment = 8192;
- OwningMemRef<float, 2> a(shape, shapeAlloc, init, alignment);
- ASSERT_EQ(
- (void *)(((uintptr_t)a->basePtr + alignment - 1) & ~(alignment - 1)),
- a->data);
- ASSERT_EQ(a->sizes[0], k);
- ASSERT_EQ(a->sizes[1], m);
- ASSERT_EQ(a->strides[0], m + 1);
- ASSERT_EQ(a->strides[1], 1);
- for (int i = 0; i < k; ++i) {
- for (int j = 0; j < m; ++j) {
- EXPECT_EQ((a[{i, j}]), i * m + j);
- EXPECT_EQ(&(a[{i, j}]), &((*a)[i][j]));
+ auto runTest = [&](std::string_view traceName,
+ ElementWiseVisitor<float> maybeInit) {
+ SCOPED_TRACE(traceName);
+ // Use a large alignment to stress the case where the memref data/basePtr
+ // are disjoint.
+ int alignment = 8192;
+ OwningMemRef<float, 2> a(shape, shapeAlloc, maybeInit, alignment);
+ ASSERT_EQ(
+ (void *)(((uintptr_t)a->basePtr + alignment - 1) & ~(alignment - 1)),
+ a->data);
+ ASSERT_EQ(a->sizes[0], k);
+ ASSERT_EQ(a->sizes[1], m);
+ ASSERT_EQ(a->strides[0], m + 1);
+ ASSERT_EQ(a->strides[1], 1);
+ if (maybeInit) {
+ for (int i = 0; i < k; ++i) {
+ for (int j = 0; j < m; ++j) {
+ EXPECT_EQ((a[{i, j}]), i * m + j);
+ EXPECT_EQ(&(a[{i, j}]), &((*a)[i][j]));
+ }
+ }
+ } else {
+ for (int i = 0; i < k; ++i) {
+ for (int j = 0; j < m; ++j) {
+ EXPECT_EQ((a[{i, j}]), 0.);
+ EXPECT_EQ(&(a[{i, j}]), &((*a)[i][j]));
+ }
+ }
}
- }
- std::string moduleStr = R"mlir(
- func.func @rank2_memref(%arg0 : memref<?x?xf32>, %arg1 : memref<?x?xf32>) attributes { llvm.emit_c_interface } {
- %x = arith.constant 2 : index
- %y = arith.constant 1 : index
- %cst42 = arith.constant 42.0 : f32
- memref.store %cst42, %arg0[%y, %x] : memref<?x?xf32>
- memref.store %cst42, %arg1[%x, %y] : memref<?x?xf32>
- return
- }
- )mlir";
- DialectRegistry registry;
- registerAllDialects(registry);
- registerBuiltinDialectTranslation(registry);
- registerLLVMDialectTranslation(registry);
- MLIRContext context(registry);
- OwningOpRef<ModuleOp> module =
- parseSourceString<ModuleOp>(moduleStr, &context);
- ASSERT_TRUE(!!module);
- ASSERT_TRUE(succeeded(lowerToLLVMDialect(*module)));
- auto jitOrError = ExecutionEngine::create(*module);
- ASSERT_TRUE(!!jitOrError);
- std::unique_ptr<ExecutionEngine> jit = std::move(jitOrError.get());
+ std::string moduleStr = R"mlir(
+ func.func @rank2_memref(%arg0 : memref<?x?xf32>, %arg1 : memref<?x?xf32>) attributes { llvm.emit_c_interface } {
+ %x = arith.constant 2 : index
+ %y = arith.constant 1 : index
+ %cst42 = arith.constant 42.0 : f32
+ memref.store %cst42, %arg0[%y, %x] : memref<?x?xf32>
+ memref.store %cst42, %arg1[%x, %y] : memref<?x?xf32>
+ return
+ }
+ )mlir";
+ DialectRegistry registry;
+ registerAllDialects(registry);
+ registerBuiltinDialectTranslation(registry);
+ registerLLVMDialectTranslation(registry);
+ MLIRContext context(registry);
+ OwningOpRef<ModuleOp> module =
+ parseSourceString<ModuleOp>(moduleStr, &context);
+ ASSERT_TRUE(!!module);
+ ASSERT_TRUE(succeeded(lowerToLLVMDialect(*module)));
+ auto jitOrError = ExecutionEngine::create(*module);
+ ASSERT_TRUE(!!jitOrError);
+ std::unique_ptr<ExecutionEngine> jit = std::move(jitOrError.get());
- llvm::Error error = jit->invoke("rank2_memref", &*a, &*a);
- ASSERT_TRUE(!error);
- EXPECT_EQ(((*a)[1][2]), 42.);
- EXPECT_EQ((a[{2, 1}]), 42.);
+ llvm::Error error = jit->invoke("rank2_memref", &*a, &*a);
+ ASSERT_TRUE(!error);
+ EXPECT_EQ(((*a)[1][2]), 42.);
+ EXPECT_EQ((a[{2, 1}]), 42.);
+ };
+
+ runTest("withInit", init);
+ runTest("withoutInit", {});
}
// A helper function that will be called from the JIT
More information about the Mlir-commits
mailing list