[Mlir-commits] [mlir] [mlir][memref] Improve runtime verification for `memref.subview` (PR #132545)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Sat Mar 22 05:23:11 PDT 2025
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir
Author: Matthias Springer (matthias-springer)
<details>
<summary>Changes</summary>
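This commit addresses a TODO in the runtime verification of `memref.subview`. Previously, only the linear bounds of the resulting memref were checked against the linear bounds of the base memref, so a subview could be out-of-bounds along one dimension and still pass verification. Each dimension is now verified separately: the offset must be in-bounds (`0 <= offset < dim_size`) and the slice must not run out-of-bounds (`0 <= offset + (size - 1) * stride < dim_size`).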
---
Full diff: https://github.com/llvm/llvm-project/pull/132545.diff
2 Files Affected:
- (modified) mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp (+38-34)
- (modified) mlir/test/Integration/Dialect/MemRef/subview-runtime-verification.mlir (+27-15)
``````````diff
diff --git a/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp b/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp
index 134e8b5efcfdf..7cd4814bf88d0 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp
@@ -327,47 +327,51 @@ struct ReinterpretCastOpInterface
}
};
-/// Verifies that the linear bounds of a subview op are within the linear bounds
-/// of the base memref: low >= baseLow && high <= baseHigh
-/// TODO: This is not yet a full runtime verification of subview. For example,
-/// consider:
-/// %m = memref.alloc(%c10, %c10) : memref<10x10xf32>
-/// memref.subview %m[%c0, %c0][%c20, %c2][%c1, %c1]
-/// : memref<?x?xf32> to memref<?x?xf32>
-/// The subview is in-bounds of the entire base memref but the first dimension
-/// is out-of-bounds. Future work would verify the bounds on a per-dimension
-/// basis.
struct SubViewOpInterface
: public RuntimeVerifiableOpInterface::ExternalModel<SubViewOpInterface,
SubViewOp> {
void generateRuntimeVerification(Operation *op, OpBuilder &builder,
Location loc) const {
auto subView = cast<SubViewOp>(op);
- auto baseMemref = cast<TypedValue<BaseMemRefType>>(subView.getSource());
- auto resultMemref = cast<TypedValue<BaseMemRefType>>(subView.getResult());
+ MemRefType sourceType = subView.getSource().getType();
- builder.setInsertionPointAfter(op);
-
- // Compute the linear bounds of the base memref
- auto [baseLow, baseHigh] = computeLinearBounds(builder, loc, baseMemref);
-
- // Compute the linear bounds of the resulting memref
- auto [low, high] = computeLinearBounds(builder, loc, resultMemref);
-
- // Check low >= baseLow
- auto geLow = builder.createOrFold<arith::CmpIOp>(
- loc, arith::CmpIPredicate::sge, low, baseLow);
-
- // Check high <= baseHigh
- auto leHigh = builder.createOrFold<arith::CmpIOp>(
- loc, arith::CmpIPredicate::sle, high, baseHigh);
-
- auto assertCond = builder.createOrFold<arith::AndIOp>(loc, geLow, leHigh);
-
- builder.create<cf::AssertOp>(
- loc, assertCond,
- RuntimeVerifiableOpInterface::generateErrorMessage(
- op, "subview is out-of-bounds of the base memref"));
+ // For each dimension, assert that:
+ // 0 <= offset < dim_size
+ // 0 <= offset + (size - 1) * stride < dim_size
+ Value zero = builder.create<arith::ConstantIndexOp>(loc, 0);
+ Value one = builder.create<arith::ConstantIndexOp>(loc, 1);
+ auto metadataOp =
+ builder.create<ExtractStridedMetadataOp>(loc, subView.getSource());
+ for (int64_t i = 0, e = sourceType.getRank(); i < e; ++i) {
+ Value offset = getValueOrCreateConstantIndexOp(
+ builder, loc, subView.getMixedOffsets()[i]);
+ Value size = getValueOrCreateConstantIndexOp(builder, loc,
+ subView.getMixedSizes()[i]);
+ Value stride = getValueOrCreateConstantIndexOp(
+ builder, loc, subView.getMixedStrides()[i]);
+
+ // Verify that offset is in-bounds.
+ Value offsetInBounds = generateInBoundsCheck(builder, loc, offset, zero,
+ metadataOp.getSizes()[i]);
+ builder.create<cf::AssertOp>(
+ loc, offsetInBounds,
+ RuntimeVerifiableOpInterface::generateErrorMessage(
+ op, "offset " + std::to_string(i) + " is out-of-bounds"));
+
+ // Verify that slice does not run out-of-bounds.
+ Value sizeMinusOne = builder.create<arith::SubIOp>(loc, size, one);
+ Value sizeMinusOneTimesStride =
+ builder.create<arith::MulIOp>(loc, sizeMinusOne, stride);
+ Value lastPos =
+ builder.create<arith::AddIOp>(loc, offset, sizeMinusOneTimesStride);
+ Value lastPosInBounds = generateInBoundsCheck(builder, loc, lastPos, zero,
+ metadataOp.getSizes()[i]);
+ builder.create<cf::AssertOp>(
+ loc, lastPosInBounds,
+ RuntimeVerifiableOpInterface::generateErrorMessage(
+ op, "subview runs out-of-bounds along dimension " +
+ std::to_string(i)));
+ }
}
};
diff --git a/mlir/test/Integration/Dialect/MemRef/subview-runtime-verification.mlir b/mlir/test/Integration/Dialect/MemRef/subview-runtime-verification.mlir
index 3cac37a082c30..ec7e4085f2fa5 100644
--- a/mlir/test/Integration/Dialect/MemRef/subview-runtime-verification.mlir
+++ b/mlir/test/Integration/Dialect/MemRef/subview-runtime-verification.mlir
@@ -39,38 +39,50 @@ func.func @main() {
%alloca_4 = memref.alloca() : memref<4x4xf32>
%alloca_4_dyn = memref.cast %alloca_4 : memref<4x4xf32> to memref<?x4xf32>
- // Offset is out-of-bounds
+ // Offset is out-of-bounds and slice runs out-of-bounds
// CHECK: ERROR: Runtime op verification failed
- // CHECK-NEXT: "memref.subview"
- // CHECK-NEXT: ^ subview is out-of-bounds of the base memref
+ // CHECK-NEXT: "memref.subview"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) <{operandSegmentSizes = array<i32: 1, 1, 1, 1>, static_offsets = array<i64: -9223372036854775808, 0>, static_sizes = array<i64: -9223372036854775808, 1>, static_strides = array<i64: -9223372036854775808, 1>}> : (memref<?x4xf32>, index, index, index) -> memref<?xf32, strided<[?], offset: ?>>
+ // CHECK-NEXT: ^ offset 0 is out-of-bounds
+ // CHECK-NEXT: Location: loc({{.*}})
+ // CHECK: ERROR: Runtime op verification failed
+ // CHECK-NEXT: "memref.subview"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) <{operandSegmentSizes = array<i32: 1, 1, 1, 1>, static_offsets = array<i64: -9223372036854775808, 0>, static_sizes = array<i64: -9223372036854775808, 1>, static_strides = array<i64: -9223372036854775808, 1>}> : (memref<?x4xf32>, index, index, index) -> memref<?xf32, strided<[?], offset: ?>>
+ // CHECK-NEXT: ^ subview runs out-of-bounds along dimension 0
// CHECK-NEXT: Location: loc({{.*}})
func.call @subview_dynamic_rank_reduce(%alloca_4_dyn, %5, %5, %1) : (memref<?x4xf32>, index, index, index) -> ()
- // Offset is out-of-bounds
+ // Offset is out-of-bounds and slice runs out-of-bounds
+ // CHECK: ERROR: Runtime op verification failed
+ // CHECK-NEXT: "memref.subview"(%{{.*}}, %{{.*}}) <{operandSegmentSizes = array<i32: 1, 1, 0, 0>, static_offsets = array<i64: -9223372036854775808>, static_sizes = array<i64: 1>, static_strides = array<i64: 1>}> : (memref<1xf32>, index) -> memref<1xf32, strided<[1], offset: ?>>
+ // CHECK-NEXT: ^ offset 0 is out-of-bounds
+ // CHECK-NEXT: Location: loc({{.*}})
// CHECK: ERROR: Runtime op verification failed
- // CHECK-NEXT: "memref.subview"
- // CHECK-NEXT: ^ subview is out-of-bounds of the base memref
+ // CHECK-NEXT: "memref.subview"(%{{.*}}, %{{.*}}) <{operandSegmentSizes = array<i32: 1, 1, 0, 0>, static_offsets = array<i64: -9223372036854775808>, static_sizes = array<i64: 1>, static_strides = array<i64: 1>}> : (memref<1xf32>, index) -> memref<1xf32, strided<[1], offset: ?>>
+ // CHECK-NEXT: ^ subview runs out-of-bounds along dimension 0
// CHECK-NEXT: Location: loc({{.*}})
func.call @subview(%alloca, %1) : (memref<1xf32>, index) -> ()
- // Offset is out-of-bounds
+ // Offset is out-of-bounds and slice runs out-of-bounds
+ // CHECK: ERROR: Runtime op verification failed
+ // CHECK-NEXT: "memref.subview"(%{{.*}}, %{{.*}}) <{operandSegmentSizes = array<i32: 1, 1, 0, 0>, static_offsets = array<i64: -9223372036854775808>, static_sizes = array<i64: 1>, static_strides = array<i64: 1>}> : (memref<1xf32>, index) -> memref<1xf32, strided<[1], offset: ?>>
+ // CHECK-NEXT: ^ offset 0 is out-of-bounds
+ // CHECK-NEXT: Location: loc({{.*}})
// CHECK: ERROR: Runtime op verification failed
- // CHECK-NEXT: "memref.subview"
- // CHECK-NEXT: ^ subview is out-of-bounds of the base memref
+ // CHECK-NEXT: "memref.subview"(%{{.*}}, %{{.*}}) <{operandSegmentSizes = array<i32: 1, 1, 0, 0>, static_offsets = array<i64: -9223372036854775808>, static_sizes = array<i64: 1>, static_strides = array<i64: 1>}> : (memref<1xf32>, index) -> memref<1xf32, strided<[1], offset: ?>>
+ // CHECK-NEXT: ^ subview runs out-of-bounds along dimension 0
// CHECK-NEXT: Location: loc({{.*}})
func.call @subview(%alloca, %n1) : (memref<1xf32>, index) -> ()
- // Size is out-of-bounds
+ // Slice runs out-of-bounds due to size
// CHECK: ERROR: Runtime op verification failed
- // CHECK-NEXT: "memref.subview"
- // CHECK-NEXT: ^ subview is out-of-bounds of the base memref
+ // CHECK-NEXT: "memref.subview"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) <{operandSegmentSizes = array<i32: 1, 1, 1, 1>, static_offsets = array<i64: -9223372036854775808, 0>, static_sizes = array<i64: -9223372036854775808, 4>, static_strides = array<i64: -9223372036854775808, 1>}> : (memref<?x4xf32>, index, index, index) -> memref<?x4xf32, strided<[?, 1], offset: ?>>
+ // CHECK-NEXT: ^ subview runs out-of-bounds along dimension 0
// CHECK-NEXT: Location: loc({{.*}})
func.call @subview_dynamic(%alloca_4_dyn, %0, %5, %1) : (memref<?x4xf32>, index, index, index) -> ()
- // Stride is out-of-bounds
+ // Slice runs out-of-bounds due to stride
// CHECK: ERROR: Runtime op verification failed
- // CHECK-NEXT: "memref.subview"
- // CHECK-NEXT: ^ subview is out-of-bounds of the base memref
+ // CHECK-NEXT: "memref.subview"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) <{operandSegmentSizes = array<i32: 1, 1, 1, 1>, static_offsets = array<i64: -9223372036854775808, 0>, static_sizes = array<i64: -9223372036854775808, 4>, static_strides = array<i64: -9223372036854775808, 1>}> : (memref<?x4xf32>, index, index, index) -> memref<?x4xf32, strided<[?, 1], offset: ?>>
+ // CHECK-NEXT: ^ subview runs out-of-bounds along dimension 0
// CHECK-NEXT: Location: loc({{.*}})
func.call @subview_dynamic(%alloca_4_dyn, %0, %4, %4) : (memref<?x4xf32>, index, index, index) -> ()
``````````
</details>
https://github.com/llvm/llvm-project/pull/132545
More information about the Mlir-commits mailing list