[Mlir-commits] [mlir] [MLIR][MemRef] Validate linear size before lowering allocs (PR #179155)
Stefan Weigl-Bosker
llvmlistbot at llvm.org
Sun Feb 1 18:26:16 PST 2026
https://github.com/sweiglbosker created https://github.com/llvm/llvm-project/pull/179155
See discussion here: https://github.com/llvm/llvm-project/pull/178395, https://github.com/llvm/llvm-project/pull/178994
We allow memrefs that have more elements than can be represented by an `int64_t`, so there are cases where the element count, sizes, strides, etc. may not fit in the memref descriptor.
With this PR, memref allocations are no longer lowered to LLVM when overflow is detected in the statically known strides or element count.
Current limitation: the size in bytes is not considered, so some overflow cases remain with `memref.alloc`, but this isn't new behavior.
Fixes #179144 as well
>From 6db64e552056b9dcb9ebd500b5bf4adf9e809673 Mon Sep 17 00:00:00 2001
From: Stefan Weigl-Bosker <stefan at s00.xyz>
Date: Sun, 1 Feb 2026 21:03:00 -0500
Subject: [PATCH] [MLIR][MemRef] Validate linear size before lowering allocs
---
.../mlir/Conversion/LLVMCommon/Pattern.h | 13 ++++----
.../Conversion/GPUCommon/GPUOpsLowering.cpp | 5 ++--
.../GPUCommon/GPUToLLVMConversion.cpp | 6 ++--
mlir/lib/Conversion/LLVMCommon/Pattern.cpp | 18 +++++++++--
.../Conversion/MemRefToLLVM/MemRefToLLVM.cpp | 28 ++++++++++++-----
.../ConvertLaunchFuncToLLVMCalls.cpp | 5 ++--
.../MemRefToLLVM/memref-to-llvm.mlir | 30 +++++++++++++++++++
.../Dialect/MemRef/high-rank-overflow.mlir | 2 --
8 files changed, 82 insertions(+), 25 deletions(-)
diff --git a/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h b/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h
index cacd500d41291..5e8782181b7b0 100644
--- a/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h
+++ b/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h
@@ -167,12 +167,13 @@ class ConvertToLLVMPattern : public ConversionPattern {
/// `strides[1]` = llvm.mlir.constant(1 : index) : i64
/// `strides[0]` = `sizes[0]`
/// %size = llvm.mul `sizes[0]`, `sizes[1]` : i64
- void getMemRefDescriptorSizes(Location loc, MemRefType memRefType,
- ValueRange dynamicSizes,
- ConversionPatternRewriter &rewriter,
- SmallVectorImpl<Value> &sizes,
- SmallVectorImpl<Value> &strides, Value &size,
- bool sizeInBytes = true) const;
+ LogicalResult getMemRefDescriptorSizes(Location loc, MemRefType memRefType,
+ ValueRange dynamicSizes,
+ ConversionPatternRewriter &rewriter,
+ SmallVectorImpl<Value> &sizes,
+ SmallVectorImpl<Value> &strides,
+ Value &size,
+ bool sizeInBytes = true) const;
/// Computes the size of type in bytes.
Value getSizeInBytes(Location loc, Type type,
diff --git a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
index 498bea0fd17b4..94ee7ca00b09f 100644
--- a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
+++ b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
@@ -759,8 +759,9 @@ LogicalResult GPUDynamicSharedMemoryOpLowering::matchAndRewrite(
// Step 5. Create a memref descriptor
SmallVector<Value> shape, strides;
Value sizeBytes;
- getMemRefDescriptorSizes(loc, memrefType0sz, {}, rewriter, shape, strides,
- sizeBytes);
+ if (failed(getMemRefDescriptorSizes(loc, memrefType0sz, {}, rewriter, shape,
+ strides, sizeBytes)))
+ return failure();
auto memRefDescriptor = this->createMemRefDescriptor(
loc, memrefType0sz, shmemPtr, shmemPtr, shape, strides, rewriter);
diff --git a/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp b/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
index 0f72bf0c0d59e..b775204ef04e8 100644
--- a/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
+++ b/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
@@ -778,8 +778,10 @@ LogicalResult ConvertAllocOpToGpuRuntimeCallPattern::matchAndRewrite(
SmallVector<Value, 4> shape;
SmallVector<Value, 4> strides;
Value sizeBytes;
- getMemRefDescriptorSizes(loc, memRefType, adaptor.getDynamicSizes(), rewriter,
- shape, strides, sizeBytes);
+ if (failed(getMemRefDescriptorSizes(loc, memRefType,
+ adaptor.getDynamicSizes(), rewriter,
+ shape, strides, sizeBytes)))
+ return failure();
// Allocate the underlying buffer and store a pointer to it in the MemRef
// descriptor.
diff --git a/mlir/lib/Conversion/LLVMCommon/Pattern.cpp b/mlir/lib/Conversion/LLVMCommon/Pattern.cpp
index 640ff3d7c3c7d..abdcc76aaa9a8 100644
--- a/mlir/lib/Conversion/LLVMCommon/Pattern.cpp
+++ b/mlir/lib/Conversion/LLVMCommon/Pattern.cpp
@@ -12,6 +12,7 @@
#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/BuiltinAttributes.h"
+#include "llvm/Support/CheckedArithmetic.h"
using namespace mlir;
@@ -85,7 +86,7 @@ Type ConvertToLLVMPattern::getElementPtrType(MemRefType type) const {
return LLVM::LLVMPointerType::get(type.getContext(), *addressSpace);
}
-void ConvertToLLVMPattern::getMemRefDescriptorSizes(
+LogicalResult ConvertToLLVMPattern::getMemRefDescriptorSizes(
Location loc, MemRefType memRefType, ValueRange dynamicSizes,
ConversionPatternRewriter &rewriter, SmallVectorImpl<Value> &sizes,
SmallVectorImpl<Value> &strides, Value &size, bool sizeInBytes) const {
@@ -107,6 +108,7 @@ void ConvertToLLVMPattern::getMemRefDescriptorSizes(
// Strides: iterate sizes in reverse order and multiply.
int64_t stride = 1;
+ unsigned indexBitWidth = getTypeConverter()->getIndexTypeBitwidth();
Value runningStride = createIndexAttrConstant(rewriter, loc, indexType, 1);
strides.resize(memRefType.getRank());
for (auto i = memRefType.getRank(); i-- > 0;) {
@@ -116,8 +118,16 @@ void ConvertToLLVMPattern::getMemRefDescriptorSizes(
bool useSizeAsStride = stride == 1;
if (staticSize == ShapedType::kDynamic)
stride = ShapedType::kDynamic;
- if (stride != ShapedType::kDynamic)
- stride *= staticSize;
+ if (stride != ShapedType::kDynamic) {
+ auto res = llvm::checkedMul(stride, staticSize);
+ if (!res)
+ return failure();
+ stride = res.value();
+
+ if (stride < 0 ||
+ !llvm::isUIntN(indexBitWidth, static_cast<uint64_t>(stride)))
+ return failure();
+ }
if (useSizeAsStride)
runningStride = sizes[i];
@@ -138,6 +148,8 @@ void ConvertToLLVMPattern::getMemRefDescriptorSizes(
} else {
size = runningStride;
}
+
+ return success();
}
Value ConvertToLLVMPattern::getSizeInBytes(
diff --git a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
index 91a0c4b55fa84..d5585ff10e0c4 100644
--- a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
+++ b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
@@ -164,8 +164,11 @@ class AllocOpLowering : public ConvertOpToLLVMPattern<memref::AllocOp> {
SmallVector<Value, 4> strides;
Value sizeBytes;
- this->getMemRefDescriptorSizes(loc, memRefType, adaptor.getOperands(),
- rewriter, sizes, strides, sizeBytes, true);
+ if (failed(this->getMemRefDescriptorSizes(loc, memRefType,
+ adaptor.getOperands(), rewriter,
+ sizes, strides, sizeBytes, true)))
+ return rewriter.notifyMatchFailure(
+ op, "unable to compute memref descriptor sizes");
Value alignment = getAlignment(rewriter, loc, op);
if (alignment) {
@@ -256,8 +259,11 @@ class AlignedAllocOpLowering : public ConvertOpToLLVMPattern<memref::AllocOp> {
SmallVector<Value, 4> strides;
Value sizeBytes;
- this->getMemRefDescriptorSizes(loc, memRefType, adaptor.getOperands(),
- rewriter, sizes, strides, sizeBytes, !false);
+ if (failed(this->getMemRefDescriptorSizes(
+ loc, memRefType, adaptor.getOperands(), rewriter, sizes, strides,
+ sizeBytes, !false)))
+ return rewriter.notifyMatchFailure(
+ op, "unable to compute memref descriptor sizes");
int64_t alignment = alignedAllocationGetAlignment(op, &defaultLayout);
@@ -349,8 +355,11 @@ struct AllocaOpLowering : public ConvertOpToLLVMPattern<memref::AllocaOp> {
SmallVector<Value, 4> strides;
Value size;
- this->getMemRefDescriptorSizes(loc, memRefType, adaptor.getOperands(),
- rewriter, sizes, strides, size, !true);
+ if (failed(this->getMemRefDescriptorSizes(loc, memRefType,
+ adaptor.getOperands(), rewriter,
+ sizes, strides, size, !true)))
+ return rewriter.notifyMatchFailure(
+ op, "unable to compute memref descriptor sizes");
// With alloca, one gets a pointer to the element type right away.
// For stack allocations.
@@ -884,8 +893,11 @@ struct GetGlobalMemrefOpLowering
SmallVector<Value, 4> strides;
Value sizeBytes;
- this->getMemRefDescriptorSizes(loc, memRefType, adaptor.getOperands(),
- rewriter, sizes, strides, sizeBytes, !false);
+ if (failed(this->getMemRefDescriptorSizes(
+ loc, memRefType, adaptor.getOperands(), rewriter, sizes, strides,
+ sizeBytes, !false)))
+ return rewriter.notifyMatchFailure(
+ op, "unable to compute memref descriptor sizes");
MemRefType type = cast<MemRefType>(op.getResult().getType());
diff --git a/mlir/lib/Conversion/SPIRVToLLVM/ConvertLaunchFuncToLLVMCalls.cpp b/mlir/lib/Conversion/SPIRVToLLVM/ConvertLaunchFuncToLLVMCalls.cpp
index 2491c7cbd3d22..c61232dec330d 100644
--- a/mlir/lib/Conversion/SPIRVToLLVM/ConvertLaunchFuncToLLVMCalls.cpp
+++ b/mlir/lib/Conversion/SPIRVToLLVM/ConvertLaunchFuncToLLVMCalls.cpp
@@ -223,8 +223,9 @@ class GPULaunchLowering : public ConvertOpToLLVMPattern<gpu::LaunchFuncOp> {
SmallVector<Value, 4> sizes;
SmallVector<Value, 4> strides;
Value sizeBytes;
- getMemRefDescriptorSizes(loc, memRefType, {}, rewriter, sizes, strides,
- sizeBytes);
+ if (failed(getMemRefDescriptorSizes(loc, memRefType, {}, rewriter, sizes,
+ strides, sizeBytes)))
+ return failure();
MemRefDescriptor descriptor(operand.value());
Value src = descriptor.allocatedPtr(rewriter, loc);
diff --git a/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir b/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir
index 0cbe064572911..2f02dd2e51772 100644
--- a/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir
+++ b/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir
@@ -830,3 +830,33 @@ func.func @alloca_unconvertable_memory_space() {
%alloca = memref.alloca() : memref<1x32x33xi32, #spirv.storage_class<StorageBuffer>>
func.return
}
+
+// -----
+
+// CHECK-LABEL: func @alloca_huge(
+// CHECK32-LABEL: func @alloca_huge(
+func.func @alloca_huge(%arg0 : index) {
+ // CHECK: %[[SIZE:.*]] = llvm.mlir.constant(8589934580 : index) : i64
+ // CHECK: llvm.mlir.constant(1 : index) : i64
+ // CHECK: alloca %[[SIZE]] x i32 : (i64) -> !llvm.ptr
+ // CHECK: %[[UNDEF:.*]] = llvm.mlir.poison : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+ // CHECK: %[[STORE0:.*]] = llvm.insertvalue %2, %[[UNDEF]][0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+ // CHECK: %[[STORE1:.*]] = llvm.insertvalue %2, %[[STORE0]][1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+ // CHECK: %[[ZERO:.*]] = llvm.mlir.constant(0 : index) : i64
+ // CHECK: %[[STORE2:.*]] = llvm.insertvalue %[[ZERO]], %[[STORE1]][2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+ // CHECK: %[[STORE3:.*]] = llvm.insertvalue %0, %[[STORE2]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+ // CHECK: llvm.insertvalue %1, %[[STORE3]][4, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)
+ // CHECK32: memref.alloca
+ %0 = memref.alloca() : memref<8589934580xi32>
+
+ // CHECK: memref.alloca
+ // CHECK32: memref.alloca
+ %1 = memref.alloca() : memref<9223372036854775807x2xi32>
+ // CHECK: memref.alloc
+ %2 = memref.alloc() : memref<9223372036854775807x2xi32>
+
+ // CHECK: memref.alloc
+ %3 = memref.alloc(%arg0) : memref<?x8589934580x17179869160xi8>
+
+ func.return
+}
diff --git a/mlir/test/Dialect/MemRef/high-rank-overflow.mlir b/mlir/test/Dialect/MemRef/high-rank-overflow.mlir
index c0dd817ccf329..2a6ec113c7261 100644
--- a/mlir/test/Dialect/MemRef/high-rank-overflow.mlir
+++ b/mlir/test/Dialect/MemRef/high-rank-overflow.mlir
@@ -1,5 +1,3 @@
-// XFAIL: ubsan
-
// RUN: mlir-opt %s --convert-to-llvm --split-input-file --verify-diagnostics | FileCheck %s
// Test that extremely high-rank memrefs with overflow in stride calculation
More information about the Mlir-commits
mailing list