[Mlir-commits] [mlir] [mlir][memref] Improve runtime verification for `memref.subview` (PR #132545)

lorenzo chelini llvmlistbot at llvm.org
Mon Mar 31 08:44:09 PDT 2025


================
@@ -327,47 +327,51 @@ struct ReinterpretCastOpInterface
   }
 };
 
-/// Verifies that the linear bounds of a subview op are within the linear bounds
-/// of the base memref: low >= baseLow && high <= baseHigh
-/// TODO: This is not yet a full runtime verification of subview. For example,
-/// consider:
-///   %m = memref.alloc(%c10, %c10) : memref<10x10xf32>
-///   memref.subview %m[%c0, %c0][%c20, %c2][%c1, %c1]
-///      : memref<?x?xf32> to memref<?x?xf32>
-/// The subview is in-bounds of the entire base memref but the first dimension
-/// is out-of-bounds. Future work would verify the bounds on a per-dimension
-/// basis.
 struct SubViewOpInterface
     : public RuntimeVerifiableOpInterface::ExternalModel<SubViewOpInterface,
                                                          SubViewOp> {
   void generateRuntimeVerification(Operation *op, OpBuilder &builder,
                                    Location loc) const {
     auto subView = cast<SubViewOp>(op);
-    auto baseMemref = cast<TypedValue<BaseMemRefType>>(subView.getSource());
-    auto resultMemref = cast<TypedValue<BaseMemRefType>>(subView.getResult());
+    MemRefType sourceType = subView.getSource().getType();
 
-    builder.setInsertionPointAfter(op);
-
-    // Compute the linear bounds of the base memref
-    auto [baseLow, baseHigh] = computeLinearBounds(builder, loc, baseMemref);
-
-    // Compute the linear bounds of the resulting memref
-    auto [low, high] = computeLinearBounds(builder, loc, resultMemref);
-
-    // Check low >= baseLow
-    auto geLow = builder.createOrFold<arith::CmpIOp>(
-        loc, arith::CmpIPredicate::sge, low, baseLow);
-
-    // Check high <= baseHigh
-    auto leHigh = builder.createOrFold<arith::CmpIOp>(
-        loc, arith::CmpIPredicate::sle, high, baseHigh);
-
-    auto assertCond = builder.createOrFold<arith::AndIOp>(loc, geLow, leHigh);
-
-    builder.create<cf::AssertOp>(
-        loc, assertCond,
-        RuntimeVerifiableOpInterface::generateErrorMessage(
-            op, "subview is out-of-bounds of the base memref"));
+    // For each dimension, assert that:
+    // 0 <= offset < dim_size
+    // 0 <= offset + (size - 1) * stride < dim_size
+    Value zero = builder.create<arith::ConstantIndexOp>(loc, 0);
+    Value one = builder.create<arith::ConstantIndexOp>(loc, 1);
+    auto metadataOp =
+        builder.create<ExtractStridedMetadataOp>(loc, subView.getSource());
+    for (int64_t i = 0, e = sourceType.getRank(); i < e; ++i) {
+      Value offset = getValueOrCreateConstantIndexOp(
+          builder, loc, subView.getMixedOffsets()[i]);
+      Value size = getValueOrCreateConstantIndexOp(builder, loc,
+                                                   subView.getMixedSizes()[i]);
+      Value stride = getValueOrCreateConstantIndexOp(
+          builder, loc, subView.getMixedStrides()[i]);
+
+      // Verify that offset is in-bounds.
+      Value offsetInBounds = generateInBoundsCheck(builder, loc, offset, zero,
+                                                   metadataOp.getSizes()[i]);
+      builder.create<cf::AssertOp>(
+          loc, offsetInBounds,
+          RuntimeVerifiableOpInterface::generateErrorMessage(
+              op, "offset " + std::to_string(i) + " is out-of-bounds"));
+
+      // Verify that slice does not run out-of-bounds.
+      Value sizeMinusOne = builder.create<arith::SubIOp>(loc, size, one);
+      Value sizeMinusOneTimesStride =
+          builder.create<arith::MulIOp>(loc, sizeMinusOne, stride);
+      Value lastPos =
+          builder.create<arith::AddIOp>(loc, offset, sizeMinusOneTimesStride);
+      Value lastPosInBounds = generateInBoundsCheck(builder, loc, lastPos, zero,
+                                                    metadataOp.getSizes()[i]);
----------------
chelini wrote:

nit: I would store `metadataOp.getSizes()[i]` in a variable, e.g. `Value dimSize = metadataOp.getSizes()[i];`, since it is used in both in-bounds checks.

https://github.com/llvm/llvm-project/pull/132545


More information about the Mlir-commits mailing list