[Mlir-commits] [mlir] [mlir] Fix memset range in `OwningMemRef` zero-init (PR #158200)

Ryan Kim llvmlistbot at llvm.org
Fri Sep 12 18:47:45 PDT 2025


https://github.com/chokobole updated https://github.com/llvm/llvm-project/pull/158200

>From dd8f24998a1b0b9730d9d5735d7fafbdcdf5dd08 Mon Sep 17 00:00:00 2001
From: Ryan Kim <chokobole33 at gmail.com>
Date: Fri, 12 Sep 2025 14:12:05 +0900
Subject: [PATCH] [mlir] Fix memset range in OwningMemRef zero-init

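When no initializer was given, `OwningMemRef` previously zero-filled its
buffer by calling `memset()` on `descriptor.data` (the aligned pointer)
with a length of `nElements * sizeof(T) + alignment`. Because the aligned
pointer can be offset from the start of the allocation
(`allocatedPtr != alignedData`), this could write past the end of the
allocated region. The fix clears exactly `nElements * sizeof(T)` bytes
starting at `alignedData`.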
---
 .../mlir/ExecutionEngine/MemRefUtils.h        |  10 +-
 mlir/unittests/ExecutionEngine/Invoke.cpp     | 100 ++++++++++--------
 2 files changed, 62 insertions(+), 48 deletions(-)

diff --git a/mlir/include/mlir/ExecutionEngine/MemRefUtils.h b/mlir/include/mlir/ExecutionEngine/MemRefUtils.h
index d66d757cb7a8e..e9471731afe13 100644
--- a/mlir/include/mlir/ExecutionEngine/MemRefUtils.h
+++ b/mlir/include/mlir/ExecutionEngine/MemRefUtils.h
@@ -164,19 +164,17 @@ class OwningMemRef {
     int64_t nElements = 1;
     for (int64_t s : shapeAlloc)
       nElements *= s;
-    auto [data, alignedData] =
+    auto [allocatedPtr, alignedData] =
         detail::allocAligned<T>(nElements, allocFun, alignment);
-    descriptor = detail::makeStridedMemRefDescriptor<Rank>(data, alignedData,
-                                                           shape, shapeAlloc);
+    descriptor = detail::makeStridedMemRefDescriptor<Rank>(
+        allocatedPtr, alignedData, shape, shapeAlloc);
     if (init) {
       for (StridedMemrefIterator<T, Rank> it = descriptor.begin(),
                                           end = descriptor.end();
            it != end; ++it)
         init(*it, it.getIndices());
     } else {
-      memset(descriptor.data, 0,
-             nElements * sizeof(T) +
-                 alignment.value_or(detail::nextPowerOf2(sizeof(T))));
+      memset(alignedData, 0, nElements * sizeof(T));
     }
   }
   /// Take ownership of an existing descriptor with a custom deleter.
diff --git a/mlir/unittests/ExecutionEngine/Invoke.cpp b/mlir/unittests/ExecutionEngine/Invoke.cpp
index cdeeca20610f0..9f1787fa8a497 100644
--- a/mlir/unittests/ExecutionEngine/Invoke.cpp
+++ b/mlir/unittests/ExecutionEngine/Invoke.cpp
@@ -205,50 +205,66 @@ TEST(NativeMemRefJit, SKIP_WITHOUT_JIT(BasicMemref)) {
   };
   int64_t shape[] = {k, m};
   int64_t shapeAlloc[] = {k + 1, m + 1};
-  // Use a large alignment to stress the case where the memref data/basePtr are
-  // disjoint.
-  int alignment = 8192;
-  OwningMemRef<float, 2> a(shape, shapeAlloc, init, alignment);
-  ASSERT_EQ(
-      (void *)(((uintptr_t)a->basePtr + alignment - 1) & ~(alignment - 1)),
-      a->data);
-  ASSERT_EQ(a->sizes[0], k);
-  ASSERT_EQ(a->sizes[1], m);
-  ASSERT_EQ(a->strides[0], m + 1);
-  ASSERT_EQ(a->strides[1], 1);
-  for (int i = 0; i < k; ++i) {
-    for (int j = 0; j < m; ++j) {
-      EXPECT_EQ((a[{i, j}]), i * m + j);
-      EXPECT_EQ(&(a[{i, j}]), &((*a)[i][j]));
+  auto runTest = [&](std::string_view traceName,
+                     ElementWiseVisitor<float> maybeInit) {
+    SCOPED_TRACE(traceName);
+    // Use a large alignment to stress the case where the memref data/basePtr
+    // are disjoint.
+    int alignment = 8192;
+    OwningMemRef<float, 2> a(shape, shapeAlloc, maybeInit, alignment);
+    ASSERT_EQ(
+        (void *)(((uintptr_t)a->basePtr + alignment - 1) & ~(alignment - 1)),
+        a->data);
+    ASSERT_EQ(a->sizes[0], k);
+    ASSERT_EQ(a->sizes[1], m);
+    ASSERT_EQ(a->strides[0], m + 1);
+    ASSERT_EQ(a->strides[1], 1);
+    if (maybeInit) {
+      for (int i = 0; i < k; ++i) {
+        for (int j = 0; j < m; ++j) {
+          EXPECT_EQ((a[{i, j}]), i * m + j);
+          EXPECT_EQ(&(a[{i, j}]), &((*a)[i][j]));
+        }
+      }
+    } else {
+      for (int i = 0; i < k; ++i) {
+        for (int j = 0; j < m; ++j) {
+          EXPECT_EQ((a[{i, j}]), 0.);
+          EXPECT_EQ(&(a[{i, j}]), &((*a)[i][j]));
+        }
+      }
     }
-  }
-  std::string moduleStr = R"mlir(
-  func.func @rank2_memref(%arg0 : memref<?x?xf32>, %arg1 : memref<?x?xf32>) attributes { llvm.emit_c_interface } {
-    %x = arith.constant 2 : index
-    %y = arith.constant 1 : index
-    %cst42 = arith.constant 42.0 : f32
-    memref.store %cst42, %arg0[%y, %x] : memref<?x?xf32>
-    memref.store %cst42, %arg1[%x, %y] : memref<?x?xf32>
-    return
-  }
-  )mlir";
-  DialectRegistry registry;
-  registerAllDialects(registry);
-  registerBuiltinDialectTranslation(registry);
-  registerLLVMDialectTranslation(registry);
-  MLIRContext context(registry);
-  OwningOpRef<ModuleOp> module =
-      parseSourceString<ModuleOp>(moduleStr, &context);
-  ASSERT_TRUE(!!module);
-  ASSERT_TRUE(succeeded(lowerToLLVMDialect(*module)));
-  auto jitOrError = ExecutionEngine::create(*module);
-  ASSERT_TRUE(!!jitOrError);
-  std::unique_ptr<ExecutionEngine> jit = std::move(jitOrError.get());
+    std::string moduleStr = R"mlir(
+    func.func @rank2_memref(%arg0 : memref<?x?xf32>, %arg1 : memref<?x?xf32>) attributes { llvm.emit_c_interface } {
+      %x = arith.constant 2 : index
+      %y = arith.constant 1 : index
+      %cst42 = arith.constant 42.0 : f32
+      memref.store %cst42, %arg0[%y, %x] : memref<?x?xf32>
+      memref.store %cst42, %arg1[%x, %y] : memref<?x?xf32>
+      return
+    }
+    )mlir";
+    DialectRegistry registry;
+    registerAllDialects(registry);
+    registerBuiltinDialectTranslation(registry);
+    registerLLVMDialectTranslation(registry);
+    MLIRContext context(registry);
+    OwningOpRef<ModuleOp> module =
+        parseSourceString<ModuleOp>(moduleStr, &context);
+    ASSERT_TRUE(!!module);
+    ASSERT_TRUE(succeeded(lowerToLLVMDialect(*module)));
+    auto jitOrError = ExecutionEngine::create(*module);
+    ASSERT_TRUE(!!jitOrError);
+    std::unique_ptr<ExecutionEngine> jit = std::move(jitOrError.get());
 
-  llvm::Error error = jit->invoke("rank2_memref", &*a, &*a);
-  ASSERT_TRUE(!error);
-  EXPECT_EQ(((*a)[1][2]), 42.);
-  EXPECT_EQ((a[{2, 1}]), 42.);
+    llvm::Error error = jit->invoke("rank2_memref", &*a, &*a);
+    ASSERT_TRUE(!error);
+    EXPECT_EQ(((*a)[1][2]), 42.);
+    EXPECT_EQ((a[{2, 1}]), 42.);
+  };
+
+  runTest("withInit", init);
+  runTest("withoutInit", {});
 }
 
 // A helper function that will be called from the JIT



More information about the Mlir-commits mailing list