[Mlir-commits] [mlir] [mlir][memref] Fix runtime verification for memref.subview when size dimension value is 0 (PR #164897)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Thu Oct 23 14:17:11 PDT 2025
https://github.com/Hanumanth04 created https://github.com/llvm/llvm-project/pull/164897
Previously, the runtime verification pass could insert assertions that fail at runtime for semantically valid `memref.subview` operations in which one of the dimensions has a size of 0.
The `memref.subview` runtime verification logic unconditionally generated a bounds check for the position of the last element (`offset + (size - 1) * stride`). When `size` is 0, this expression evaluates to `offset - stride`, which can lie outside the source bounds (for example, -1 when the offset is 0 and the stride is 1). The assertion then fails at runtime even though the operation is semantically valid.
This patch fixes the issue by making the `lastPos` check conditional. The offset is always verified, but the endpoint check is only performed when `size > 0` to avoid generating spurious assert statements.
This issue was discovered in a LiteRT model, where a dynamic shape calculation produced a zero-sized dimension that was passed to `memref.subview`. The following is a simplified IR snippet from the model. The first element of `%alloca` is copied from `@__constant_1xi32`, which holds 0, so the SSA value `%5` evaluates to 0 at runtime. As a result, running the runtime verification pass generates an assertion for `%subview_1` that always fails.
```mlir
// Simplified repro: at runtime the sizes of %subview_1 evaluate to [0, 4, 1].
// A zero-sized slice is valid, but the old runtime verification still emitted
// a last-position check for dimension 0 that always fails.
module {
// Size templates: -1 entries are replaced below by the corresponding source
// extent (10, 4, 1); any other value is used as the size directly.
memref.global "private" constant @__constant_2xi32 : memref<2xi32> = dense<-1> {alignment = 64 : i64}
memref.global "private" constant @__constant_1xi32 : memref<1xi32> = dense<0> {alignment = 64 : i64}
func.func @simpleRepro(%arg0: memref<10x4x1xf32, strided<[?, ?, ?], offset: ?>>) -> memref<?x?x?xf32, strided<[?, ?, ?], offset: ?>> {
%c2 = arith.constant 2 : index
%c4 = arith.constant 4 : index
%c1 = arith.constant 1 : index
%c10 = arith.constant 10 : index
%c0 = arith.constant 0 : index
%c-1 = arith.constant -1 : index
%0 = memref.get_global @__constant_1xi32 : memref<1xi32>
%1 = memref.get_global @__constant_2xi32 : memref<2xi32>
// Assemble %alloca = [0, -1, -1]: element 0 comes from @__constant_1xi32,
// elements 1-2 come from @__constant_2xi32.
%alloca = memref.alloca() {alignment = 64 : i64} : memref<3xi32>
%subview = memref.subview %alloca[0] [1] [1] : memref<3xi32> to memref<1xi32, strided<[1]>>
memref.copy %0, %subview : memref<1xi32> to memref<1xi32, strided<[1]>>
%subview_0 = memref.subview %alloca[1] [2] [1] : memref<3xi32> to memref<2xi32, strided<[1], offset: 1>>
memref.copy %1, %subview_0 : memref<2xi32> to memref<2xi32, strided<[1], offset: 1>>
// Dimension 0 size: loads 0, which is not -1, so %5 = 0.
%2 = memref.load %alloca[%c0] : memref<3xi32>
%3 = index.casts %2 : i32 to index
%4 = arith.cmpi eq, %3, %c-1 : index
%5 = arith.select %4, %c10, %3 : index
// Dimension 1 size: loads -1, which is replaced by 4, so %9 = 4.
%6 = memref.load %alloca[%c1] : memref<3xi32>
%7 = index.casts %6 : i32 to index
%8 = arith.cmpi eq, %7, %c-1 : index
%9 = arith.select %8, %c4, %7 : index
// Dimension 2 size: loads -1, which is replaced by 1, so %13 = 1.
%10 = memref.load %alloca[%c2] : memref<3xi32>
%11 = index.casts %10 : i32 to index
%12 = arith.cmpi eq, %11, %c-1 : index
%13 = arith.select %12, %c1, %11 : index
// Sizes are [0, 4, 1] with offset 0 and stride 1, which is a valid zero-sized
// subview. The unconditional last-position check computed
// offset + (size - 1) * stride = -1 for dimension 0, so its assertion failed.
%subview_1 = memref.subview %arg0[0, 0, 0] [%5, %9, %13] [1, 1, 1] : memref<10x4x1xf32, strided<[?, ?, ?], offset: ?>> to memref<?x?x?xf32, strided<[?, ?, ?], offset: ?>>
return %subview_1 : memref<?x?x?xf32, strided<[?, ?, ?], offset: ?>>
}
}
```
P.S. This is a similar issue to the one fixed for `tensor.extract_slice` in https://github.com/llvm/llvm-project/pull/164878
>From 7a91711c9ee70b363f5ec40e6a587ce9026f1377 Mon Sep 17 00:00:00 2001
From: Hanumanth Hanumantharayappa <hhanuman at ah-hhanuman-l.dhcp.mathworks.com>
Date: Thu, 23 Oct 2025 17:14:32 -0400
Subject: [PATCH] [mlir][memref] Fix runtime verification for memref.subview
when size is 0
---
.../Transforms/RuntimeOpVerification.cpp | 42 ++++++++++++++++++-
.../MemRef/subview-runtime-verification.mlir | 19 +++++++++
2 files changed, 60 insertions(+), 1 deletion(-)
diff --git a/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp b/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp
index 291da1f76ca9b..1979d5b7e6310 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp
@@ -13,6 +13,7 @@
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlow.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
+#include "mlir/Dialect/Index/IR/IndexOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
#include "mlir/Interfaces/RuntimeVerifiableOpInterface.h"
@@ -273,7 +274,9 @@ struct SubViewOpInterface
Value one = arith::ConstantIndexOp::create(builder, loc, 1);
auto metadataOp =
ExtractStridedMetadataOp::create(builder, loc, subView.getSource());
- for (int64_t i = 0, e = sourceType.getRank(); i < e; ++i) {
+ for (int64_t i : llvm::seq<int64_t>(0, sourceType.getRank())) {
+ // Reset insertion point to before the operation for each dimension
+ builder.setInsertionPoint(subView);
Value offset = getValueOrCreateConstantIndexOp(
builder, loc, subView.getMixedOffsets()[i]);
Value size = getValueOrCreateConstantIndexOp(builder, loc,
@@ -290,6 +293,42 @@ struct SubViewOpInterface
std::to_string(i) +
" is out-of-bounds"));
+ // Only verify if size > 0
+ Value sizeIsNonZero = arith::CmpIOp::create(
+ builder, loc, arith::CmpIPredicate::sgt, size, zero);
+
+ /*
+ * Split the current block to create the below control flow structure:
+ *
+ * ^preCondBlock:
+ * ... // offset check already done above
+ * %size_nonzero = arith.cmpi sgt, %size, %zero
+ * cf.cond_br %size_nonzero, ^sizeBoundsCheckBlock, ^afterCheckBlock
+ *
+ * ^sizeBoundsCheckBlock:
+ * %last_pos = ... // compute offset + (size-1) * stride
+ * %last_pos_ok = ... // last position bounds check
+ * cf.assert %last_pos_ok, "subview runs out-of-bounds along dimension i"
+ * cf.br ^afterCheckBlock
+ *
+ * ^afterCheckBlock:
+ * memref.subview ... // the original operation
+ */
+ Block *preCondBlock = builder.getBlock();
+ Block *afterCheckBlock = preCondBlock->splitBlock(subView);
+
+ // Create the block for conditional size bounds verification.
+ Block *sizeBoundsCheckBlock = builder.createBlock(
+ preCondBlock->getParent(), Region::iterator(afterCheckBlock));
+
+ // Terminate the pre-condition block with the conditional branch.
+ builder.setInsertionPointToEnd(preCondBlock);
+ cf::CondBranchOp::create(builder, loc, sizeIsNonZero,
+ sizeBoundsCheckBlock, afterCheckBlock);
+
+ // Populate the size bounds check block with lastPos verification.
+ builder.setInsertionPointToStart(sizeBoundsCheckBlock);
+
// Verify that slice does not run out-of-bounds.
Value sizeMinusOne = arith::SubIOp::create(builder, loc, size, one);
Value sizeMinusOneTimesStride =
@@ -303,6 +342,7 @@ struct SubViewOpInterface
generateErrorMessage(op,
"subview runs out-of-bounds along dimension " +
std::to_string(i)));
+ cf::BranchOp::create(builder, loc, afterCheckBlock);
}
}
};
diff --git a/mlir/test/Integration/Dialect/MemRef/subview-runtime-verification.mlir b/mlir/test/Integration/Dialect/MemRef/subview-runtime-verification.mlir
index 71e813c0a6300..001c435086976 100644
--- a/mlir/test/Integration/Dialect/MemRef/subview-runtime-verification.mlir
+++ b/mlir/test/Integration/Dialect/MemRef/subview-runtime-verification.mlir
@@ -38,6 +38,17 @@ func.func @subview_dynamic_rank_reduce(%memref: memref<?x4xf32>, %offset: index,
return
}
+// Called from @main with sizes [0, 4, 1]: a zero-sized dimension is valid and must not fail runtime verification.
+func.func @subview_zero_size_dim(%memref: memref<10x4x1xf32, strided<[?, ?, ?], offset: ?>>,
+ %dim_0: index,
+ %dim_1: index,
+ %dim_2: index) {
+ %subview = memref.subview %memref[0, 0, 0] [%dim_0, %dim_1, %dim_2] [1, 1, 1] :
+ memref<10x4x1xf32, strided<[?, ?, ?], offset: ?>> to
+ memref<?x?x?xf32, strided<[?, ?, ?], offset: ?>>
+ return
+}
+
func.func @main() {
%0 = arith.constant 0 : index
%1 = arith.constant 1 : index
@@ -105,6 +116,14 @@ func.func @main() {
// CHECK-NOT: ERROR: Runtime op verification failed
func.call @subview_dynamic_rank_reduce(%alloca_4_dyn, %0, %1, %0) : (memref<?x4xf32>, index, index, index) -> ()
+ %alloca_10x4x1 = memref.alloca() : memref<10x4x1xf32>
+ %alloca_10x4x1_dyn_stride = memref.cast %alloca_10x4x1 : memref<10x4x1xf32> to memref<10x4x1xf32, strided<[?, ?, ?], offset: ?>>
+ // CHECK-NOT: ERROR: Runtime op verification failed
+ %dim_0 = arith.constant 0 : index
+ %dim_1 = arith.constant 4 : index
+ %dim_2 = arith.constant 1 : index
+ func.call @subview_zero_size_dim(%alloca_10x4x1_dyn_stride, %dim_0, %dim_1, %dim_2)
+ : (memref<10x4x1xf32, strided<[?, ?, ?], offset: ?>>, index, index, index) -> ()
return
}
More information about the Mlir-commits
mailing list