[Mlir-commits] [mlir] [mlir][memref] Fix runtime verification for memref.subview for empty memref subviews (PR #166581)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Wed Nov 5 08:30:32 PST 2025
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-memref
Author: Hanumanth (Hanumanth04)
<details>
<summary>Changes</summary>
This PR applies the same fix from #166569 to `memref.subview`. That PR fixed the issue for `tensor.extract_slice`; this one addresses the identical problem for `memref.subview`.
The runtime verification for `memref.subview` incorrectly rejects valid empty subviews (size = 0) that start at the memref boundary.
**Example that demonstrates the issue:**
```mlir
// Takes a subview of %memref starting at [%offset, 0, 0] with sizes
// [%dim_0, %dim_1, %dim_2] and unit strides. With %dim_0 = 0 the subview is
// empty along dimension 0, so %offset may legally equal the dimension size
// (10) without any memory being accessed.
func.func @subview_with_empty_slice(%memref: memref<10x4x1xf32, strided<[?, ?, ?], offset: ?>>,
                                    %dim_0: index,
                                    %dim_1: index,
                                    %dim_2: index,
                                    %offset: index) {
  // When called with: offset=10, dim_0=0, dim_1=4, dim_2=1
  // Without the fix, runtime verification fails with:
  //   "offset 0 is out-of-bounds"   (the "0" is the dimension index)
  %subview = memref.subview %memref[%offset, 0, 0] [%dim_0, %dim_1, %dim_2] [1, 1, 1] :
    memref<10x4x1xf32, strided<[?, ?, ?], offset: ?>> to
    memref<?x?x?xf32, strided<[?, ?, ?], offset: ?>>
  return
}
```
When `%offset=10` and `%dim_0=0`, we're creating an empty subview (zero elements along dimension 0) starting at the boundary. The current verification enforces `offset < dim_size`, which evaluates to `10 < 10` and fails. This should be valid: the subview is empty, so no memory is accessed.
**The fix:**
Same as #166569: make the offset check conditional on the subview size:
- Empty subview (size == 0): allow `0 <= offset <= dim_size`
- Non-empty subview (size > 0): require `0 <= offset < dim_size`
Please see #166569 for the motivation and rationale.
---
---
Full diff: https://github.com/llvm/llvm-project/pull/166581.diff
2 Files Affected:
- (modified) mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp (+55-34)
- (modified) mlir/test/Integration/Dialect/MemRef/subview-runtime-verification.mlir (+15)
``````````diff
diff --git a/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp b/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp
index 14152c5a1af0c..e5cc41e2c43ba 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp
@@ -268,61 +268,82 @@ struct SubViewOpInterface
MemRefType sourceType = subView.getSource().getType();
// For each dimension, assert that:
- // 0 <= offset < dim_size
- // 0 <= offset + (size - 1) * stride < dim_size
+ // For empty slices (size == 0) : 0 <= offset <= dim_size
+ // For non-empty slices (size > 0): 0 <= offset < dim_size
+ // 0 <= offset + (size - 1) * stride <
+ // dim_size
Value zero = arith::ConstantIndexOp::create(builder, loc, 0);
Value one = arith::ConstantIndexOp::create(builder, loc, 1);
+
auto metadataOp =
ExtractStridedMetadataOp::create(builder, loc, subView.getSource());
+
for (int64_t i : llvm::seq<int64_t>(0, sourceType.getRank())) {
- // Reset insertion point to before the operation for each dimension
+ // Reset insertion point to before the operation for each dimension.
builder.setInsertionPoint(subView);
+
Value offset = getValueOrCreateConstantIndexOp(
builder, loc, subView.getMixedOffsets()[i]);
Value size = getValueOrCreateConstantIndexOp(builder, loc,
subView.getMixedSizes()[i]);
Value stride = getValueOrCreateConstantIndexOp(
builder, loc, subView.getMixedStrides()[i]);
-
- // Verify that offset is in-bounds.
Value dimSize = metadataOp.getSizes()[i];
- Value offsetInBounds =
- generateInBoundsCheck(builder, loc, offset, zero, dimSize);
- cf::AssertOp::create(builder, loc, offsetInBounds,
+
+ // Verify that offset is in-bounds (conditional on slice size).
+ Value sizeIsZero = arith::CmpIOp::create(
+ builder, loc, arith::CmpIPredicate::eq, size, zero);
+ auto offsetCheckIf = scf::IfOp::create(
+ builder, loc, sizeIsZero,
+ [&](OpBuilder &b, Location loc) {
+ // For empty slices, offset can be at the boundary: 0 <= offset <=
+ // dimSize.
+ Value offsetGEZero = arith::CmpIOp::create(
+ b, loc, arith::CmpIPredicate::sge, offset, zero);
+ Value offsetLEDimSize = arith::CmpIOp::create(
+ b, loc, arith::CmpIPredicate::sle, offset, dimSize);
+ Value emptyOffsetValid =
+ arith::AndIOp::create(b, loc, offsetGEZero, offsetLEDimSize);
+ scf::YieldOp::create(b, loc, emptyOffsetValid);
+ },
+ [&](OpBuilder &b, Location loc) {
+ // For non-empty slices, offset must be a valid index: 0 <= offset <
+ // dimSize.
+ Value offsetInBounds =
+ generateInBoundsCheck(b, loc, offset, zero, dimSize);
+ scf::YieldOp::create(b, loc, offsetInBounds);
+ });
+
+ Value offsetCondition = offsetCheckIf.getResult(0);
+ cf::AssertOp::create(builder, loc, offsetCondition,
generateErrorMessage(op, "offset " +
std::to_string(i) +
" is out-of-bounds"));
- // Only verify if size > 0
+ // Verify that the slice endpoint is in-bounds (only for non-empty
+ // slices).
Value sizeIsNonZero = arith::CmpIOp::create(
builder, loc, arith::CmpIPredicate::sgt, size, zero);
+ auto ifOp = scf::IfOp::create(
+ builder, loc, sizeIsNonZero,
+ [&](OpBuilder &b, Location loc) {
+ // Verify that slice does not run out-of-bounds.
+ Value sizeMinusOne = arith::SubIOp::create(b, loc, size, one);
+ Value sizeMinusOneTimesStride =
+ arith::MulIOp::create(b, loc, sizeMinusOne, stride);
+ Value lastPos =
+ arith::AddIOp::create(b, loc, offset, sizeMinusOneTimesStride);
+ Value lastPosInBounds =
+ generateInBoundsCheck(b, loc, lastPos, zero, dimSize);
+ scf::YieldOp::create(b, loc, lastPosInBounds);
+ },
+ [&](OpBuilder &b, Location loc) {
+ Value trueVal =
+ arith::ConstantOp::create(b, loc, b.getBoolAttr(true));
+ scf::YieldOp::create(b, loc, trueVal);
+ });
- auto ifOp = scf::IfOp::create(builder, loc, builder.getI1Type(),
- sizeIsNonZero, /*withElseRegion=*/true);
-
- // Populate the "then" region (for size > 0).
- builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
-
- // Verify that slice does not run out-of-bounds.
- Value sizeMinusOne = arith::SubIOp::create(builder, loc, size, one);
- Value sizeMinusOneTimesStride =
- arith::MulIOp::create(builder, loc, sizeMinusOne, stride);
- Value lastPos =
- arith::AddIOp::create(builder, loc, offset, sizeMinusOneTimesStride);
- Value lastPosInBounds =
- generateInBoundsCheck(builder, loc, lastPos, zero, dimSize);
-
- scf::YieldOp::create(builder, loc, lastPosInBounds);
-
- // Populate the "else" region (for size == 0).
- builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
- Value trueVal =
- arith::ConstantOp::create(builder, loc, builder.getBoolAttr(true));
- scf::YieldOp::create(builder, loc, trueVal);
-
- builder.setInsertionPointAfter(ifOp);
Value finalCondition = ifOp.getResult(0);
-
cf::AssertOp::create(
builder, loc, finalCondition,
generateErrorMessage(op,
diff --git a/mlir/test/Integration/Dialect/MemRef/subview-runtime-verification.mlir b/mlir/test/Integration/Dialect/MemRef/subview-runtime-verification.mlir
index 84875675ac3d0..09cfee16ccd00 100644
--- a/mlir/test/Integration/Dialect/MemRef/subview-runtime-verification.mlir
+++ b/mlir/test/Integration/Dialect/MemRef/subview-runtime-verification.mlir
@@ -50,6 +50,17 @@ func.func @subview_zero_size_dim(%memref: memref<10x4x1xf32, strided<[?, ?, ?],
return
}
+func.func @subview_with_empty_slice(%memref: memref<10x4x1xf32, strided<[?, ?, ?], offset: ?>>,
+ %dim_0: index,
+ %dim_1: index,
+ %dim_2: index,
+ %offset: index) {
+ %subview = memref.subview %memref[%offset, 0, 0] [%dim_0, %dim_1, %dim_2] [1, 1, 1] :
+ memref<10x4x1xf32, strided<[?, ?, ?], offset: ?>> to
+ memref<?x?x?xf32, strided<[?, ?, ?], offset: ?>>
+ return
+}
+
func.func @main() {
%0 = arith.constant 0 : index
@@ -127,5 +138,9 @@ func.func @main() {
func.call @subview_zero_size_dim(%alloca_10x4x1_dyn_stride, %dim_0, %dim_1, %dim_2)
: (memref<10x4x1xf32, strided<[?, ?, ?], offset: ?>>, index, index, index) -> ()
+ // CHECK-NOT: ERROR: Runtime op verification failed
+ %offset = arith.constant 10 : index
+ func.call @subview_with_empty_slice(%alloca_10x4x1_dyn_stride, %dim_0, %dim_1, %dim_2, %offset)
+ : (memref<10x4x1xf32, strided<[?, ?, ?], offset: ?>>, index, index, index, index) -> ()
return
}
``````````
</details>
https://github.com/llvm/llvm-project/pull/166581
More information about the Mlir-commits
mailing list