[Mlir-commits] [mlir] [mlir][memref] Improve runtime verification for `memref.subview` (PR #132545)
lorenzo chelini
llvmlistbot at llvm.org
Mon Mar 31 08:44:09 PDT 2025
================
@@ -327,47 +327,51 @@ struct ReinterpretCastOpInterface
}
};
-/// Verifies that the linear bounds of a subview op are within the linear bounds
-/// of the base memref: low >= baseLow && high <= baseHigh
-/// TODO: This is not yet a full runtime verification of subview. For example,
-/// consider:
-/// %m = memref.alloc(%c10, %c10) : memref<10x10xf32>
-/// memref.subview %m[%c0, %c0][%c20, %c2][%c1, %c1]
-/// : memref<?x?xf32> to memref<?x?xf32>
-/// The subview is in-bounds of the entire base memref but the first dimension
-/// is out-of-bounds. Future work would verify the bounds on a per-dimension
-/// basis.
struct SubViewOpInterface
: public RuntimeVerifiableOpInterface::ExternalModel<SubViewOpInterface,
SubViewOp> {
void generateRuntimeVerification(Operation *op, OpBuilder &builder,
Location loc) const {
auto subView = cast<SubViewOp>(op);
- auto baseMemref = cast<TypedValue<BaseMemRefType>>(subView.getSource());
- auto resultMemref = cast<TypedValue<BaseMemRefType>>(subView.getResult());
+ MemRefType sourceType = subView.getSource().getType();
- builder.setInsertionPointAfter(op);
-
- // Compute the linear bounds of the base memref
- auto [baseLow, baseHigh] = computeLinearBounds(builder, loc, baseMemref);
-
- // Compute the linear bounds of the resulting memref
- auto [low, high] = computeLinearBounds(builder, loc, resultMemref);
-
- // Check low >= baseLow
- auto geLow = builder.createOrFold<arith::CmpIOp>(
- loc, arith::CmpIPredicate::sge, low, baseLow);
-
- // Check high <= baseHigh
- auto leHigh = builder.createOrFold<arith::CmpIOp>(
- loc, arith::CmpIPredicate::sle, high, baseHigh);
-
- auto assertCond = builder.createOrFold<arith::AndIOp>(loc, geLow, leHigh);
-
- builder.create<cf::AssertOp>(
- loc, assertCond,
- RuntimeVerifiableOpInterface::generateErrorMessage(
- op, "subview is out-of-bounds of the base memref"));
+ // For each dimension, assert that:
+ // 0 <= offset < dim_size
+ // 0 <= offset + (size - 1) * stride < dim_size
+ Value zero = builder.create<arith::ConstantIndexOp>(loc, 0);
+ Value one = builder.create<arith::ConstantIndexOp>(loc, 1);
+ auto metadataOp =
+ builder.create<ExtractStridedMetadataOp>(loc, subView.getSource());
+ for (int64_t i = 0, e = sourceType.getRank(); i < e; ++i) {
+ Value offset = getValueOrCreateConstantIndexOp(
+ builder, loc, subView.getMixedOffsets()[i]);
+ Value size = getValueOrCreateConstantIndexOp(builder, loc,
+ subView.getMixedSizes()[i]);
+ Value stride = getValueOrCreateConstantIndexOp(
+ builder, loc, subView.getMixedStrides()[i]);
+
+ // Verify that offset is in-bounds.
+ Value offsetInBounds = generateInBoundsCheck(builder, loc, offset, zero,
+ metadataOp.getSizes()[i]);
+ builder.create<cf::AssertOp>(
+ loc, offsetInBounds,
+ RuntimeVerifiableOpInterface::generateErrorMessage(
+ op, "offset " + std::to_string(i) + " is out-of-bounds"));
+
+ // Verify that slice does not run out-of-bounds.
+ Value sizeMinusOne = builder.create<arith::SubIOp>(loc, size, one);
+ Value sizeMinusOneTimesStride =
+ builder.create<arith::MulIOp>(loc, sizeMinusOne, stride);
+ Value lastPos =
+ builder.create<arith::AddIOp>(loc, offset, sizeMinusOneTimesStride);
+ Value lastPosInBounds = generateInBoundsCheck(builder, loc, lastPos, zero,
+ metadataOp.getSizes()[i]);
----------------
chelini wrote:
nit: `metadataOp.getSizes()[i]` is used twice; I would hoist it into a local variable, e.g. `Value dimSize = metadataOp.getSizes()[i];`
https://github.com/llvm/llvm-project/pull/132545
More information about the Mlir-commits
mailing list