[Mlir-commits] [mlir] [MLIR][MemRef] Validate linear size before lowering allocs (PR #179155)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Sun Feb 1 18:26:49 PST 2026
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-llvm
Author: Stefan Weigl-Bosker (sweiglbosker)
<details>
<summary>Changes</summary>
See discussion here: https://github.com/llvm/llvm-project/pull/178395, https://github.com/llvm/llvm-project/pull/178994
We allow memrefs that have more elements than can be represented by an `int64_t`, so there are cases where the element count, sizes, strides, etc. may not fit in the memref descriptor.
With this PR, memref allocations are no longer lowered to LLVM when overflow is detected in the statically known strides or element count.
Current limitations: the check doesn't consider the size in bytes (so some overflow cases remain with memref.alloc, but this isn't new behavior).
Fixes #<!-- -->179144 as well
---
Full diff: https://github.com/llvm/llvm-project/pull/179155.diff
8 Files Affected:
- (modified) mlir/include/mlir/Conversion/LLVMCommon/Pattern.h (+7-6)
- (modified) mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp (+3-2)
- (modified) mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp (+4-2)
- (modified) mlir/lib/Conversion/LLVMCommon/Pattern.cpp (+15-3)
- (modified) mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp (+20-8)
- (modified) mlir/lib/Conversion/SPIRVToLLVM/ConvertLaunchFuncToLLVMCalls.cpp (+3-2)
- (modified) mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir (+30)
- (modified) mlir/test/Dialect/MemRef/high-rank-overflow.mlir (-2)
``````````diff
diff --git a/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h b/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h
index cacd500d41291..5e8782181b7b0 100644
--- a/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h
+++ b/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h
@@ -167,12 +167,13 @@ class ConvertToLLVMPattern : public ConversionPattern {
/// `strides[1]` = llvm.mlir.constant(1 : index) : i64
/// `strides[0]` = `sizes[0]`
/// %size = llvm.mul `sizes[0]`, `sizes[1]` : i64
- void getMemRefDescriptorSizes(Location loc, MemRefType memRefType,
- ValueRange dynamicSizes,
- ConversionPatternRewriter &rewriter,
- SmallVectorImpl<Value> &sizes,
- SmallVectorImpl<Value> &strides, Value &size,
- bool sizeInBytes = true) const;
+ LogicalResult getMemRefDescriptorSizes(Location loc, MemRefType memRefType,
+ ValueRange dynamicSizes,
+ ConversionPatternRewriter &rewriter,
+ SmallVectorImpl<Value> &sizes,
+ SmallVectorImpl<Value> &strides,
+ Value &size,
+ bool sizeInBytes = true) const;
/// Computes the size of type in bytes.
Value getSizeInBytes(Location loc, Type type,
diff --git a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
index 498bea0fd17b4..94ee7ca00b09f 100644
--- a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
+++ b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
@@ -759,8 +759,9 @@ LogicalResult GPUDynamicSharedMemoryOpLowering::matchAndRewrite(
// Step 5. Create a memref descriptor
SmallVector<Value> shape, strides;
Value sizeBytes;
- getMemRefDescriptorSizes(loc, memrefType0sz, {}, rewriter, shape, strides,
- sizeBytes);
+ if (failed(getMemRefDescriptorSizes(loc, memrefType0sz, {}, rewriter, shape,
+ strides, sizeBytes)))
+ return failure();
auto memRefDescriptor = this->createMemRefDescriptor(
loc, memrefType0sz, shmemPtr, shmemPtr, shape, strides, rewriter);
diff --git a/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp b/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
index 0f72bf0c0d59e..b775204ef04e8 100644
--- a/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
+++ b/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
@@ -778,8 +778,10 @@ LogicalResult ConvertAllocOpToGpuRuntimeCallPattern::matchAndRewrite(
SmallVector<Value, 4> shape;
SmallVector<Value, 4> strides;
Value sizeBytes;
- getMemRefDescriptorSizes(loc, memRefType, adaptor.getDynamicSizes(), rewriter,
- shape, strides, sizeBytes);
+ if (failed(getMemRefDescriptorSizes(loc, memRefType,
+ adaptor.getDynamicSizes(), rewriter,
+ shape, strides, sizeBytes)))
+ return failure();
// Allocate the underlying buffer and store a pointer to it in the MemRef
// descriptor.
diff --git a/mlir/lib/Conversion/LLVMCommon/Pattern.cpp b/mlir/lib/Conversion/LLVMCommon/Pattern.cpp
index 640ff3d7c3c7d..abdcc76aaa9a8 100644
--- a/mlir/lib/Conversion/LLVMCommon/Pattern.cpp
+++ b/mlir/lib/Conversion/LLVMCommon/Pattern.cpp
@@ -12,6 +12,7 @@
#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/BuiltinAttributes.h"
+#include "llvm/Support/CheckedArithmetic.h"
using namespace mlir;
@@ -85,7 +86,7 @@ Type ConvertToLLVMPattern::getElementPtrType(MemRefType type) const {
return LLVM::LLVMPointerType::get(type.getContext(), *addressSpace);
}
-void ConvertToLLVMPattern::getMemRefDescriptorSizes(
+LogicalResult ConvertToLLVMPattern::getMemRefDescriptorSizes(
Location loc, MemRefType memRefType, ValueRange dynamicSizes,
ConversionPatternRewriter &rewriter, SmallVectorImpl<Value> &sizes,
SmallVectorImpl<Value> &strides, Value &size, bool sizeInBytes) const {
@@ -107,6 +108,7 @@ void ConvertToLLVMPattern::getMemRefDescriptorSizes(
// Strides: iterate sizes in reverse order and multiply.
int64_t stride = 1;
+ unsigned indexBitWidth = getTypeConverter()->getIndexTypeBitwidth();
Value runningStride = createIndexAttrConstant(rewriter, loc, indexType, 1);
strides.resize(memRefType.getRank());
for (auto i = memRefType.getRank(); i-- > 0;) {
@@ -116,8 +118,16 @@ void ConvertToLLVMPattern::getMemRefDescriptorSizes(
bool useSizeAsStride = stride == 1;
if (staticSize == ShapedType::kDynamic)
stride = ShapedType::kDynamic;
- if (stride != ShapedType::kDynamic)
- stride *= staticSize;
+ if (stride != ShapedType::kDynamic) {
+ auto res = llvm::checkedMul(stride, staticSize);
+ if (!res)
+ return failure();
+ stride = res.value();
+
+ if (stride < 0 ||
+ !llvm::isUIntN(indexBitWidth, static_cast<uint64_t>(stride)))
+ return failure();
+ }
if (useSizeAsStride)
runningStride = sizes[i];
@@ -138,6 +148,8 @@ void ConvertToLLVMPattern::getMemRefDescriptorSizes(
} else {
size = runningStride;
}
+
+ return success();
}
Value ConvertToLLVMPattern::getSizeInBytes(
diff --git a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
index 91a0c4b55fa84..d5585ff10e0c4 100644
--- a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
+++ b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
@@ -164,8 +164,11 @@ class AllocOpLowering : public ConvertOpToLLVMPattern<memref::AllocOp> {
SmallVector<Value, 4> strides;
Value sizeBytes;
- this->getMemRefDescriptorSizes(loc, memRefType, adaptor.getOperands(),
- rewriter, sizes, strides, sizeBytes, true);
+ if (failed(this->getMemRefDescriptorSizes(loc, memRefType,
+ adaptor.getOperands(), rewriter,
+ sizes, strides, sizeBytes, true)))
+ return rewriter.notifyMatchFailure(
+ op, "unable to compute memref descriptor sizes");
Value alignment = getAlignment(rewriter, loc, op);
if (alignment) {
@@ -256,8 +259,11 @@ class AlignedAllocOpLowering : public ConvertOpToLLVMPattern<memref::AllocOp> {
SmallVector<Value, 4> strides;
Value sizeBytes;
- this->getMemRefDescriptorSizes(loc, memRefType, adaptor.getOperands(),
- rewriter, sizes, strides, sizeBytes, !false);
+ if (failed(this->getMemRefDescriptorSizes(
+ loc, memRefType, adaptor.getOperands(), rewriter, sizes, strides,
+ sizeBytes, !false)))
+ return rewriter.notifyMatchFailure(
+ op, "unable to compute memref descriptor sizes");
int64_t alignment = alignedAllocationGetAlignment(op, &defaultLayout);
@@ -349,8 +355,11 @@ struct AllocaOpLowering : public ConvertOpToLLVMPattern<memref::AllocaOp> {
SmallVector<Value, 4> strides;
Value size;
- this->getMemRefDescriptorSizes(loc, memRefType, adaptor.getOperands(),
- rewriter, sizes, strides, size, !true);
+ if (failed(this->getMemRefDescriptorSizes(loc, memRefType,
+ adaptor.getOperands(), rewriter,
+ sizes, strides, size, !true)))
+ return rewriter.notifyMatchFailure(
+ op, "unable to compute memref descriptor sizes");
// With alloca, one gets a pointer to the element type right away.
// For stack allocations.
@@ -884,8 +893,11 @@ struct GetGlobalMemrefOpLowering
SmallVector<Value, 4> strides;
Value sizeBytes;
- this->getMemRefDescriptorSizes(loc, memRefType, adaptor.getOperands(),
- rewriter, sizes, strides, sizeBytes, !false);
+ if (failed(this->getMemRefDescriptorSizes(
+ loc, memRefType, adaptor.getOperands(), rewriter, sizes, strides,
+ sizeBytes, !false)))
+ return rewriter.notifyMatchFailure(
+ op, "unable to compute memref descriptor sizes");
MemRefType type = cast<MemRefType>(op.getResult().getType());
diff --git a/mlir/lib/Conversion/SPIRVToLLVM/ConvertLaunchFuncToLLVMCalls.cpp b/mlir/lib/Conversion/SPIRVToLLVM/ConvertLaunchFuncToLLVMCalls.cpp
index 2491c7cbd3d22..c61232dec330d 100644
--- a/mlir/lib/Conversion/SPIRVToLLVM/ConvertLaunchFuncToLLVMCalls.cpp
+++ b/mlir/lib/Conversion/SPIRVToLLVM/ConvertLaunchFuncToLLVMCalls.cpp
@@ -223,8 +223,9 @@ class GPULaunchLowering : public ConvertOpToLLVMPattern<gpu::LaunchFuncOp> {
SmallVector<Value, 4> sizes;
SmallVector<Value, 4> strides;
Value sizeBytes;
- getMemRefDescriptorSizes(loc, memRefType, {}, rewriter, sizes, strides,
- sizeBytes);
+ if (failed(getMemRefDescriptorSizes(loc, memRefType, {}, rewriter, sizes,
+ strides, sizeBytes)))
+ return failure();
MemRefDescriptor descriptor(operand.value());
Value src = descriptor.allocatedPtr(rewriter, loc);
diff --git a/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir b/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir
index 0cbe064572911..2f02dd2e51772 100644
--- a/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir
+++ b/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir
@@ -830,3 +830,33 @@ func.func @alloca_unconvertable_memory_space() {
%alloca = memref.alloca() : memref<1x32x33xi32, #spirv.storage_class<StorageBuffer>>
func.return
}
+
+// -----
+
+// CHECK-LABEL: func @alloca_huge(
+// CHECK32-LABEL: func @alloca_huge(
+func.func @alloca_huge(%arg0 : index) {
+ // CHECK: %[[SIZE:.*]] = llvm.mlir.constant(8589934580 : index) : i64
+ // CHECK: llvm.mlir.constant(1 : index) : i64
+ // CHECK: alloca %[[SIZE]] x i32 : (i64) -> !llvm.ptr
+ // CHECK: %[[UNDEF:.*]] = llvm.mlir.poison : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+ // CHECK: %[[STORE0:.*]] = llvm.insertvalue %2, %[[UNDEF]][0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+ // CHECK: %[[STORE1:.*]] = llvm.insertvalue %2, %[[STORE0]][1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+ // CHECK: %[[ZERO:.*]] = llvm.mlir.constant(0 : index) : i64
+ // CHECK: %[[STORE2:.*]] = llvm.insertvalue %[[ZERO]], %[[STORE1]][2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+ // CHECK: %[[STORE3:.*]] = llvm.insertvalue %0, %[[STORE2]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+ // CHECK: llvm.insertvalue %1, %[[STORE3]][4, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)
+ // CHECK32: memref.alloca
+ %0 = memref.alloca() : memref<8589934580xi32>
+
+ // CHECK: memref.alloca
+ // CHECK32: memref.alloca
+ %1 = memref.alloca() : memref<9223372036854775807x2xi32>
+ // CHECK: memref.alloc
+ %2 = memref.alloc() : memref<9223372036854775807x2xi32>
+
+ // CHECK: memref.alloc
+ %3 = memref.alloc(%arg0) : memref<?x8589934580x17179869160xi8>
+
+ func.return
+}
diff --git a/mlir/test/Dialect/MemRef/high-rank-overflow.mlir b/mlir/test/Dialect/MemRef/high-rank-overflow.mlir
index c0dd817ccf329..2a6ec113c7261 100644
--- a/mlir/test/Dialect/MemRef/high-rank-overflow.mlir
+++ b/mlir/test/Dialect/MemRef/high-rank-overflow.mlir
@@ -1,5 +1,3 @@
-// XFAIL: ubsan
-
// RUN: mlir-opt %s --convert-to-llvm --split-input-file --verify-diagnostics | FileCheck %s
// Test that extremely high-rank memrefs with overflow in stride calculation
``````````
</details>
https://github.com/llvm/llvm-project/pull/179155
More information about the Mlir-commits
mailing list