[llvm-branch-commits] [mlir] [mlir][memref] Remove runtime verification for `memref.reinterpret_cast` (PR #132547)

Matthias Springer via llvm-branch-commits llvm-branch-commits at lists.llvm.org
Sat Mar 22 05:29:11 PDT 2025


https://github.com/matthias-springer created https://github.com/llvm/llvm-project/pull/132547

The runtime verification code used to verify that the result of a `memref.reinterpret_cast` is in-bounds with respect to the source memref. This is incorrect: `memref.reinterpret_cast` allows users to construct almost arbitrary memref descriptors, and there is no correctness expectation, i.e., the result is not required to stay within the bounds of the source memref.

This op is supposed to be used when the user "knows what they are doing." Similarly, the static verifier of `memref.reinterpret_cast` does not verify in-bounds semantics either.

>From b52e2fde970610bb749195259d915e488d66f9c8 Mon Sep 17 00:00:00 2001
From: Matthias Springer <mspringer at nvidia.com>
Date: Sat, 22 Mar 2025 13:24:53 +0100
Subject: [PATCH] [mlir][memref] Remove runtime verification for
 `memref.reinterpret_cast`

The runtime verification code used to verify that the result of a `memref.reinterpret_cast` is in-bounds with respect to the source memref. This is incorrect: `memref.reinterpret_cast` allows users to construct almost arbitrary memref descriptors and there is no correctness expectation. This op is supposed to be used when the user "knows what they are doing." Similarly, the static verifier of `memref.reinterpret_cast` does not verify in-bounds semantics either.
---
 .../Transforms/RuntimeOpVerification.cpp      | 74 +------------------
 ...reinterpret-cast-runtime-verification.mlir | 74 -------------------
 2 files changed, 1 insertion(+), 147 deletions(-)
 delete mode 100644 mlir/test/Integration/Dialect/MemRef/reinterpret-cast-runtime-verification.mlir

diff --git a/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp b/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp
index 7cd4814bf88d0..922111e1fad1f 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp
@@ -255,78 +255,6 @@ struct LoadStoreOpInterface
   }
 };
 
-/// Compute the linear index for the provided strided layout and indices.
-Value computeLinearIndex(OpBuilder &builder, Location loc, OpFoldResult offset,
-                         ArrayRef<OpFoldResult> strides,
-                         ArrayRef<OpFoldResult> indices) {
-  auto [expr, values] = computeLinearIndex(offset, strides, indices);
-  auto index =
-      affine::makeComposedFoldedAffineApply(builder, loc, expr, values);
-  return getValueOrCreateConstantIndexOp(builder, loc, index);
-}
-
-/// Returns two Values representing the bounds of the provided strided layout
-/// metadata. The bounds are returned as a half open interval -- [low, high).
-std::pair<Value, Value> computeLinearBounds(OpBuilder &builder, Location loc,
-                                            OpFoldResult offset,
-                                            ArrayRef<OpFoldResult> strides,
-                                            ArrayRef<OpFoldResult> sizes) {
-  auto zeros = SmallVector<int64_t>(sizes.size(), 0);
-  auto indices = getAsIndexOpFoldResult(builder.getContext(), zeros);
-  auto lowerBound = computeLinearIndex(builder, loc, offset, strides, indices);
-  auto upperBound = computeLinearIndex(builder, loc, offset, strides, sizes);
-  return {lowerBound, upperBound};
-}
-
-/// Returns two Values representing the bounds of the memref. The bounds are
-/// returned as a half open interval -- [low, high).
-std::pair<Value, Value> computeLinearBounds(OpBuilder &builder, Location loc,
-                                            TypedValue<BaseMemRefType> memref) {
-  auto runtimeMetadata = builder.create<ExtractStridedMetadataOp>(loc, memref);
-  auto offset = runtimeMetadata.getConstifiedMixedOffset();
-  auto strides = runtimeMetadata.getConstifiedMixedStrides();
-  auto sizes = runtimeMetadata.getConstifiedMixedSizes();
-  return computeLinearBounds(builder, loc, offset, strides, sizes);
-}
-
-/// Verifies that the linear bounds of a reinterpret_cast op are within the
-/// linear bounds of the base memref: low >= baseLow && high <= baseHigh
-struct ReinterpretCastOpInterface
-    : public RuntimeVerifiableOpInterface::ExternalModel<
-          ReinterpretCastOpInterface, ReinterpretCastOp> {
-  void generateRuntimeVerification(Operation *op, OpBuilder &builder,
-                                   Location loc) const {
-    auto reinterpretCast = cast<ReinterpretCastOp>(op);
-    auto baseMemref = reinterpretCast.getSource();
-    auto resultMemref =
-        cast<TypedValue<BaseMemRefType>>(reinterpretCast.getResult());
-
-    builder.setInsertionPointAfter(op);
-
-    // Compute the linear bounds of the base memref
-    auto [baseLow, baseHigh] = computeLinearBounds(builder, loc, baseMemref);
-
-    // Compute the linear bounds of the resulting memref
-    auto [low, high] = computeLinearBounds(builder, loc, resultMemref);
-
-    // Check low >= baseLow
-    auto geLow = builder.createOrFold<arith::CmpIOp>(
-        loc, arith::CmpIPredicate::sge, low, baseLow);
-
-    // Check high <= baseHigh
-    auto leHigh = builder.createOrFold<arith::CmpIOp>(
-        loc, arith::CmpIPredicate::sle, high, baseHigh);
-
-    auto assertCond = builder.createOrFold<arith::AndIOp>(loc, geLow, leHigh);
-
-    builder.create<cf::AssertOp>(
-        loc, assertCond,
-        RuntimeVerifiableOpInterface::generateErrorMessage(
-            op,
-            "result of reinterpret_cast is out-of-bounds of the base memref"));
-  }
-};
-
 struct SubViewOpInterface
     : public RuntimeVerifiableOpInterface::ExternalModel<SubViewOpInterface,
                                                          SubViewOp> {
@@ -430,9 +358,9 @@ void mlir::memref::registerRuntimeVerifiableOpInterfaceExternalModels(
     DimOp::attachInterface<DimOpInterface>(*ctx);
     ExpandShapeOp::attachInterface<ExpandShapeOpInterface>(*ctx);
     LoadOp::attachInterface<LoadStoreOpInterface<LoadOp>>(*ctx);
-    ReinterpretCastOp::attachInterface<ReinterpretCastOpInterface>(*ctx);
     StoreOp::attachInterface<LoadStoreOpInterface<StoreOp>>(*ctx);
     SubViewOp::attachInterface<SubViewOpInterface>(*ctx);
+    // Note: There is nothing to verify for ReinterpretCastOp.
 
     // Load additional dialects of which ops may get created.
     ctx->loadDialect<affine::AffineDialect, arith::ArithDialect,
diff --git a/mlir/test/Integration/Dialect/MemRef/reinterpret-cast-runtime-verification.mlir b/mlir/test/Integration/Dialect/MemRef/reinterpret-cast-runtime-verification.mlir
deleted file mode 100644
index 601a53f4b5cd9..0000000000000
--- a/mlir/test/Integration/Dialect/MemRef/reinterpret-cast-runtime-verification.mlir
+++ /dev/null
@@ -1,74 +0,0 @@
-// RUN: mlir-opt %s -generate-runtime-verification \
-// RUN:     -test-cf-assert \
-// RUN:     -expand-strided-metadata \
-// RUN:     -lower-affine \
-// RUN:     -convert-to-llvm | \
-// RUN: mlir-runner -e main -entry-point-result=void \
-// RUN:     -shared-libs=%mlir_runner_utils 2>&1 | \
-// RUN: FileCheck %s
-
-func.func @reinterpret_cast(%memref: memref<1xf32>, %offset: index) {
-    memref.reinterpret_cast %memref to
-                    offset: [%offset],
-                    sizes: [1],
-                    strides: [1]
-                  : memref<1xf32> to  memref<1xf32, strided<[1], offset: ?>>
-    return
-}
-
-func.func @reinterpret_cast_fully_dynamic(%memref: memref<?xf32>, %offset: index, %size: index, %stride: index)  {
-    memref.reinterpret_cast %memref to
-                    offset: [%offset],
-                    sizes: [%size],
-                    strides: [%stride]
-                  : memref<?xf32> to  memref<?xf32, strided<[?], offset: ?>>
-    return
-}
-
-func.func @main() {
-  %0 = arith.constant 0 : index
-  %1 = arith.constant 1 : index
-  %n1 = arith.constant -1 : index
-  %4 = arith.constant 4 : index
-  %5 = arith.constant 5 : index
-
-  %alloca_1 = memref.alloca() : memref<1xf32>
-  %alloca_4 = memref.alloca() : memref<4xf32>
-  %alloca_4_dyn = memref.cast %alloca_4 : memref<4xf32> to memref<?xf32>
-
-  // Offset is out-of-bounds
-  //      CHECK: ERROR: Runtime op verification failed
-  // CHECK-NEXT: "memref.reinterpret_cast"(%{{.*}})
-  // CHECK-NEXT: ^ result of reinterpret_cast is out-of-bounds of the base memref
-  // CHECK-NEXT: Location: loc({{.*}})
-  func.call @reinterpret_cast(%alloca_1, %1) : (memref<1xf32>, index) -> ()
-
-  // Offset is out-of-bounds
-  //      CHECK: ERROR: Runtime op verification failed
-  // CHECK-NEXT: "memref.reinterpret_cast"(%{{.*}})
-  // CHECK-NEXT: ^ result of reinterpret_cast is out-of-bounds of the base memref
-  // CHECK-NEXT: Location: loc({{.*}})
-  func.call @reinterpret_cast(%alloca_1, %n1) : (memref<1xf32>, index) -> ()
-
-  // Size is out-of-bounds
-  //      CHECK: ERROR: Runtime op verification failed
-  // CHECK-NEXT: "memref.reinterpret_cast"(%{{.*}})
-  // CHECK-NEXT: ^ result of reinterpret_cast is out-of-bounds of the base memref
-  // CHECK-NEXT: Location: loc({{.*}})
-  func.call @reinterpret_cast_fully_dynamic(%alloca_4_dyn, %0, %5, %1) : (memref<?xf32>, index, index, index) -> ()
-
-  // Stride is out-of-bounds
-  //      CHECK: ERROR: Runtime op verification failed
-  // CHECK-NEXT: "memref.reinterpret_cast"(%{{.*}})
-  // CHECK-NEXT: ^ result of reinterpret_cast is out-of-bounds of the base memref
-  // CHECK-NEXT: Location: loc({{.*}})
-  func.call @reinterpret_cast_fully_dynamic(%alloca_4_dyn, %0, %4, %4) : (memref<?xf32>, index, index, index) -> ()
-
-  //  CHECK-NOT: ERROR: Runtime op verification failed
-  func.call @reinterpret_cast(%alloca_1, %0) : (memref<1xf32>, index) -> ()
-
-  //  CHECK-NOT: ERROR: Runtime op verification failed
-  func.call @reinterpret_cast_fully_dynamic(%alloca_4_dyn, %0, %4, %1) : (memref<?xf32>, index, index, index) -> ()
-
-  return
-}



More information about the llvm-branch-commits mailing list