[Mlir-commits] [mlir] c5e5c97 - [MLIR][MemRef] Validate linear size before lowering allocs (#179155)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Mar 2 05:59:28 PST 2026


Author: sweiglbosker
Date: 2026-03-02T13:59:23Z
New Revision: c5e5c9735a33909e2268f99e84654e46bd98b109

URL: https://github.com/llvm/llvm-project/commit/c5e5c9735a33909e2268f99e84654e46bd98b109
DIFF: https://github.com/llvm/llvm-project/commit/c5e5c9735a33909e2268f99e84654e46bd98b109.diff

LOG: [MLIR][MemRef] Validate linear size before lowering allocs (#179155)

See the discussion in https://github.com/llvm/llvm-project/pull/178395
and https://github.com/llvm/llvm-project/pull/178994.

Detect when the total number of elements overflows, and use a poison
value as the allocation size instead of the wrapped-around (overflowed) size.

Added: 
    

Modified: 
    mlir/lib/Conversion/LLVMCommon/Pattern.cpp
    mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir
    mlir/test/Dialect/MemRef/high-rank-overflow.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Conversion/LLVMCommon/Pattern.cpp b/mlir/lib/Conversion/LLVMCommon/Pattern.cpp
index 4592ff8425687..2e0d92c3ba847 100644
--- a/mlir/lib/Conversion/LLVMCommon/Pattern.cpp
+++ b/mlir/lib/Conversion/LLVMCommon/Pattern.cpp
@@ -12,6 +12,7 @@
 #include "mlir/Dialect/LLVMIR/LLVMTypes.h"
 #include "mlir/IR/AffineMap.h"
 #include "mlir/IR/BuiltinAttributes.h"
+#include "llvm/Support/CheckedArithmetic.h"
 #include "llvm/Support/MathExtras.h"
 
 using namespace mlir;
@@ -108,19 +109,29 @@ void ConvertToLLVMPattern::getMemRefDescriptorSizes(
 
   // Strides: iterate sizes in reverse order and multiply.
   int64_t stride = 1;
+  bool overflowed = false;
   Value runningStride = createIndexAttrConstant(rewriter, loc, indexType, 1);
   strides.resize(memRefType.getRank());
   for (auto i = memRefType.getRank(); i-- > 0;) {
-    strides[i] = runningStride;
+    strides[i] = overflowed ? LLVM::PoisonOp::create(rewriter, loc, indexType)
+                            : runningStride;
 
     int64_t staticSize = memRefType.getShape()[i];
     bool useSizeAsStride = stride == 1;
     if (staticSize == ShapedType::kDynamic)
       stride = ShapedType::kDynamic;
-    if (stride != ShapedType::kDynamic)
-      stride *= staticSize;
+    if (stride != ShapedType::kDynamic) {
+      std::optional<int64_t> res = llvm::checkedMul(stride, staticSize);
 
-    if (useSizeAsStride)
+      if (!res)
+        overflowed = true;
+      else
+        stride = res.value();
+    }
+
+    if (overflowed)
+      runningStride = LLVM::PoisonOp::create(rewriter, loc, indexType);
+    else if (useSizeAsStride)
       runningStride = sizes[i];
     else if (stride == ShapedType::kDynamic)
       runningStride =

diff  --git a/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir b/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir
index 0cbe064572911..d2fe5ab582b71 100644
--- a/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir
+++ b/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir
@@ -830,3 +830,17 @@ func.func @alloca_unconvertable_memory_space() {
   %alloca = memref.alloca() : memref<1x32x33xi32, #spirv.storage_class<StorageBuffer>>
   func.return
 }
+
+// -----
+
+// CHECK-LABEL: func @alloca_huge(
+func.func @alloca_huge(%arg0 : index) {
+  // CHECK: %[[D0:.*]] = llvm.mlir.constant(9223372036854775807 : index) : i64
+  // CHECK: %[[D1:.*]] = llvm.mlir.constant(3 : index) : i64
+  // CHECK: %[[C1:.*]] = llvm.mlir.constant(1 : index) : i64
+  // CHECK: %[[NUMELTS:.*]] = llvm.mlir.poison : i64
+  // CHECK: llvm.alloca %[[NUMELTS]] x i32 : (i64) -> !llvm.ptr
+  %1 = memref.alloca() : memref<9223372036854775807x3xi32>
+
+  func.return
+}

diff  --git a/mlir/test/Dialect/MemRef/high-rank-overflow.mlir b/mlir/test/Dialect/MemRef/high-rank-overflow.mlir
index c0dd817ccf329..2a6ec113c7261 100644
--- a/mlir/test/Dialect/MemRef/high-rank-overflow.mlir
+++ b/mlir/test/Dialect/MemRef/high-rank-overflow.mlir
@@ -1,5 +1,3 @@
-// XFAIL: ubsan
-
 // RUN: mlir-opt %s --convert-to-llvm --split-input-file --verify-diagnostics | FileCheck %s
 
 // Test that extremely high-rank memrefs with overflow in stride calculation


        


More information about the Mlir-commits mailing list