[Mlir-commits] [mlir] [MLIR][MemRef] Validate linear size before lowering allocs (PR #179155)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Mar 2 05:50:03 PST 2026


https://github.com/sweiglbosker updated https://github.com/llvm/llvm-project/pull/179155

From 0aace16f6185207459b7d0cb3f88baea5f1b120d Mon Sep 17 00:00:00 2001
From: Stefan Weigl-Bosker <stefan at s00.xyz>
Date: Fri, 6 Feb 2026 18:13:44 -0500
Subject: [PATCH] [MLIR][MemRef] Validate linear size before lowering allocs

---
 mlir/lib/Conversion/LLVMCommon/Pattern.cpp    | 19 +++++++++++++++----
 .../MemRefToLLVM/memref-to-llvm.mlir          | 14 ++++++++++++++
 .../Dialect/MemRef/high-rank-overflow.mlir    |  2 --
 3 files changed, 29 insertions(+), 6 deletions(-)

diff --git a/mlir/lib/Conversion/LLVMCommon/Pattern.cpp b/mlir/lib/Conversion/LLVMCommon/Pattern.cpp
index 640ff3d7c3c7d..b27f0d2453086 100644
--- a/mlir/lib/Conversion/LLVMCommon/Pattern.cpp
+++ b/mlir/lib/Conversion/LLVMCommon/Pattern.cpp
@@ -12,6 +12,7 @@
 #include "mlir/Dialect/LLVMIR/LLVMTypes.h"
 #include "mlir/IR/AffineMap.h"
 #include "mlir/IR/BuiltinAttributes.h"
+#include "llvm/Support/CheckedArithmetic.h"
 
 using namespace mlir;
 
@@ -107,19 +108,29 @@ void ConvertToLLVMPattern::getMemRefDescriptorSizes(
 
   // Strides: iterate sizes in reverse order and multiply.
   int64_t stride = 1;
+  bool overflowed = false;
   Value runningStride = createIndexAttrConstant(rewriter, loc, indexType, 1);
   strides.resize(memRefType.getRank());
   for (auto i = memRefType.getRank(); i-- > 0;) {
-    strides[i] = runningStride;
+    strides[i] = overflowed ? LLVM::PoisonOp::create(rewriter, loc, indexType)
+                            : runningStride;
 
     int64_t staticSize = memRefType.getShape()[i];
     bool useSizeAsStride = stride == 1;
     if (staticSize == ShapedType::kDynamic)
       stride = ShapedType::kDynamic;
-    if (stride != ShapedType::kDynamic)
-      stride *= staticSize;
+    if (stride != ShapedType::kDynamic) {
+      std::optional<int64_t> res = llvm::checkedMul(stride, staticSize);
 
-    if (useSizeAsStride)
+      if (!res)
+        overflowed = true;
+      else
+        stride = res.value();
+    }
+
+    if (overflowed)
+      runningStride = LLVM::PoisonOp::create(rewriter, loc, indexType);
+    else if (useSizeAsStride)
       runningStride = sizes[i];
     else if (stride == ShapedType::kDynamic)
       runningStride =
diff --git a/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir b/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir
index 0cbe064572911..d2fe5ab582b71 100644
--- a/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir
+++ b/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir
@@ -830,3 +830,17 @@ func.func @alloca_unconvertable_memory_space() {
   %alloca = memref.alloca() : memref<1x32x33xi32, #spirv.storage_class<StorageBuffer>>
   func.return
 }
+
+// -----
+
+// CHECK-LABEL: func @alloca_huge(
+func.func @alloca_huge() {
+  // CHECK: %[[D0:.*]] = llvm.mlir.constant(9223372036854775807 : index) : i64
+  // CHECK: %[[D1:.*]] = llvm.mlir.constant(3 : index) : i64
+  // CHECK: %[[C1:.*]] = llvm.mlir.constant(1 : index) : i64
+  // CHECK: %[[NUMELTS:.*]] = llvm.mlir.poison : i64
+  // CHECK: llvm.alloca %[[NUMELTS]] x i32 : (i64) -> !llvm.ptr
+  // 9223372036854775807 * 3 overflows int64_t, so the element count is poison.
+  %0 = memref.alloca() : memref<9223372036854775807x3xi32>
+  func.return
+}
diff --git a/mlir/test/Dialect/MemRef/high-rank-overflow.mlir b/mlir/test/Dialect/MemRef/high-rank-overflow.mlir
index c0dd817ccf329..2a6ec113c7261 100644
--- a/mlir/test/Dialect/MemRef/high-rank-overflow.mlir
+++ b/mlir/test/Dialect/MemRef/high-rank-overflow.mlir
@@ -1,5 +1,3 @@
-// XFAIL: ubsan
-
 // RUN: mlir-opt %s --convert-to-llvm --split-input-file --verify-diagnostics | FileCheck %s
 
 // Test that extremely high-rank memrefs with overflow in stride calculation



More information about the Mlir-commits mailing list