[Mlir-commits] [mlir] fd161cf - [mlir][memref] Remove runtime verification for `memref.reinterpret_cast` (#132547)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Tue May 6 00:40:31 PDT 2025
Author: Matthias Springer
Date: 2025-05-06T09:40:28+02:00
New Revision: fd161cf56f4356c38f82a6d68a80236e00bce39d
URL: https://github.com/llvm/llvm-project/commit/fd161cf56f4356c38f82a6d68a80236e00bce39d
DIFF: https://github.com/llvm/llvm-project/commit/fd161cf56f4356c38f82a6d68a80236e00bce39d.diff
LOG: [mlir][memref] Remove runtime verification for `memref.reinterpret_cast` (#132547)
The runtime verification code used to verify that the result of a
`memref.reinterpret_cast` is in-bounds with respect to the source
memref. This check is incorrect: `memref.reinterpret_cast` allows users to
construct almost arbitrary memref descriptors, and there is no
expectation that the resulting descriptor is in-bounds.
This op is intended for cases where the user "knows what they are
doing." Likewise, the static verifier of `memref.reinterpret_cast` does
not check in-bounds semantics.
Added:
Modified:
mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp
Removed:
mlir/test/Integration/Dialect/MemRef/reinterpret-cast-runtime-verification.mlir
################################################################################
diff --git a/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp b/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp
index 2e97fafdceace..7bf7c7b8e024c 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp
@@ -255,78 +255,6 @@ struct LoadStoreOpInterface
}
};
-/// Compute the linear index for the provided strided layout and indices.
-Value computeLinearIndex(OpBuilder &builder, Location loc, OpFoldResult offset,
- ArrayRef<OpFoldResult> strides,
- ArrayRef<OpFoldResult> indices) {
- auto [expr, values] = computeLinearIndex(offset, strides, indices);
- auto index =
- affine::makeComposedFoldedAffineApply(builder, loc, expr, values);
- return getValueOrCreateConstantIndexOp(builder, loc, index);
-}
-
-/// Returns two Values representing the bounds of the provided strided layout
-/// metadata. The bounds are returned as a half open interval -- [low, high).
-std::pair<Value, Value> computeLinearBounds(OpBuilder &builder, Location loc,
- OpFoldResult offset,
- ArrayRef<OpFoldResult> strides,
- ArrayRef<OpFoldResult> sizes) {
- auto zeros = SmallVector<int64_t>(sizes.size(), 0);
- auto indices = getAsIndexOpFoldResult(builder.getContext(), zeros);
- auto lowerBound = computeLinearIndex(builder, loc, offset, strides, indices);
- auto upperBound = computeLinearIndex(builder, loc, offset, strides, sizes);
- return {lowerBound, upperBound};
-}
-
-/// Returns two Values representing the bounds of the memref. The bounds are
-/// returned as a half open interval -- [low, high).
-std::pair<Value, Value> computeLinearBounds(OpBuilder &builder, Location loc,
- TypedValue<BaseMemRefType> memref) {
- auto runtimeMetadata = builder.create<ExtractStridedMetadataOp>(loc, memref);
- auto offset = runtimeMetadata.getConstifiedMixedOffset();
- auto strides = runtimeMetadata.getConstifiedMixedStrides();
- auto sizes = runtimeMetadata.getConstifiedMixedSizes();
- return computeLinearBounds(builder, loc, offset, strides, sizes);
-}
-
-/// Verifies that the linear bounds of a reinterpret_cast op are within the
-/// linear bounds of the base memref: low >= baseLow && high <= baseHigh
-struct ReinterpretCastOpInterface
- : public RuntimeVerifiableOpInterface::ExternalModel<
- ReinterpretCastOpInterface, ReinterpretCastOp> {
- void generateRuntimeVerification(Operation *op, OpBuilder &builder,
- Location loc) const {
- auto reinterpretCast = cast<ReinterpretCastOp>(op);
- auto baseMemref = reinterpretCast.getSource();
- auto resultMemref =
- cast<TypedValue<BaseMemRefType>>(reinterpretCast.getResult());
-
- builder.setInsertionPointAfter(op);
-
- // Compute the linear bounds of the base memref
- auto [baseLow, baseHigh] = computeLinearBounds(builder, loc, baseMemref);
-
- // Compute the linear bounds of the resulting memref
- auto [low, high] = computeLinearBounds(builder, loc, resultMemref);
-
- // Check low >= baseLow
- auto geLow = builder.createOrFold<arith::CmpIOp>(
- loc, arith::CmpIPredicate::sge, low, baseLow);
-
- // Check high <= baseHigh
- auto leHigh = builder.createOrFold<arith::CmpIOp>(
- loc, arith::CmpIPredicate::sle, high, baseHigh);
-
- auto assertCond = builder.createOrFold<arith::AndIOp>(loc, geLow, leHigh);
-
- builder.create<cf::AssertOp>(
- loc, assertCond,
- RuntimeVerifiableOpInterface::generateErrorMessage(
- op,
- "result of reinterpret_cast is out-of-bounds of the base memref"));
- }
-};
-
struct SubViewOpInterface
: public RuntimeVerifiableOpInterface::ExternalModel<SubViewOpInterface,
SubViewOp> {
@@ -434,9 +362,9 @@ void mlir::memref::registerRuntimeVerifiableOpInterfaceExternalModels(
GenericAtomicRMWOp::attachInterface<
LoadStoreOpInterface<GenericAtomicRMWOp>>(*ctx);
LoadOp::attachInterface<LoadStoreOpInterface<LoadOp>>(*ctx);
- ReinterpretCastOp::attachInterface<ReinterpretCastOpInterface>(*ctx);
StoreOp::attachInterface<LoadStoreOpInterface<StoreOp>>(*ctx);
SubViewOp::attachInterface<SubViewOpInterface>(*ctx);
+ // Note: There is nothing to verify for ReinterpretCastOp.
// Load additional dialects of which ops may get created.
ctx->loadDialect<affine::AffineDialect, arith::ArithDialect,
diff --git a/mlir/test/Integration/Dialect/MemRef/reinterpret-cast-runtime-verification.mlir b/mlir/test/Integration/Dialect/MemRef/reinterpret-cast-runtime-verification.mlir
deleted file mode 100644
index 601a53f4b5cd9..0000000000000
--- a/mlir/test/Integration/Dialect/MemRef/reinterpret-cast-runtime-verification.mlir
+++ /dev/null
@@ -1,74 +0,0 @@
-// RUN: mlir-opt %s -generate-runtime-verification \
-// RUN: -test-cf-assert \
-// RUN: -expand-strided-metadata \
-// RUN: -lower-affine \
-// RUN: -convert-to-llvm | \
-// RUN: mlir-runner -e main -entry-point-result=void \
-// RUN: -shared-libs=%mlir_runner_utils 2>&1 | \
-// RUN: FileCheck %s
-
-func.func @reinterpret_cast(%memref: memref<1xf32>, %offset: index) {
- memref.reinterpret_cast %memref to
- offset: [%offset],
- sizes: [1],
- strides: [1]
- : memref<1xf32> to memref<1xf32, strided<[1], offset: ?>>
- return
-}
-
-func.func @reinterpret_cast_fully_dynamic(%memref: memref<?xf32>, %offset: index, %size: index, %stride: index) {
- memref.reinterpret_cast %memref to
- offset: [%offset],
- sizes: [%size],
- strides: [%stride]
- : memref<?xf32> to memref<?xf32, strided<[?], offset: ?>>
- return
-}
-
-func.func @main() {
- %0 = arith.constant 0 : index
- %1 = arith.constant 1 : index
- %n1 = arith.constant -1 : index
- %4 = arith.constant 4 : index
- %5 = arith.constant 5 : index
-
- %alloca_1 = memref.alloca() : memref<1xf32>
- %alloca_4 = memref.alloca() : memref<4xf32>
- %alloca_4_dyn = memref.cast %alloca_4 : memref<4xf32> to memref<?xf32>
-
- // Offset is out-of-bounds
- // CHECK: ERROR: Runtime op verification failed
- // CHECK-NEXT: "memref.reinterpret_cast"(%{{.*}})
- // CHECK-NEXT: ^ result of reinterpret_cast is out-of-bounds of the base memref
- // CHECK-NEXT: Location: loc({{.*}})
- func.call @reinterpret_cast(%alloca_1, %1) : (memref<1xf32>, index) -> ()
-
- // Offset is out-of-bounds
- // CHECK: ERROR: Runtime op verification failed
- // CHECK-NEXT: "memref.reinterpret_cast"(%{{.*}})
- // CHECK-NEXT: ^ result of reinterpret_cast is out-of-bounds of the base memref
- // CHECK-NEXT: Location: loc({{.*}})
- func.call @reinterpret_cast(%alloca_1, %n1) : (memref<1xf32>, index) -> ()
-
- // Size is out-of-bounds
- // CHECK: ERROR: Runtime op verification failed
- // CHECK-NEXT: "memref.reinterpret_cast"(%{{.*}})
- // CHECK-NEXT: ^ result of reinterpret_cast is out-of-bounds of the base memref
- // CHECK-NEXT: Location: loc({{.*}})
- func.call @reinterpret_cast_fully_dynamic(%alloca_4_dyn, %0, %5, %1) : (memref<?xf32>, index, index, index) -> ()
-
- // Stride is out-of-bounds
- // CHECK: ERROR: Runtime op verification failed
- // CHECK-NEXT: "memref.reinterpret_cast"(%{{.*}})
- // CHECK-NEXT: ^ result of reinterpret_cast is out-of-bounds of the base memref
- // CHECK-NEXT: Location: loc({{.*}})
- func.call @reinterpret_cast_fully_dynamic(%alloca_4_dyn, %0, %4, %4) : (memref<?xf32>, index, index, index) -> ()
-
- // CHECK-NOT: ERROR: Runtime op verification failed
- func.call @reinterpret_cast(%alloca_1, %0) : (memref<1xf32>, index) -> ()
-
- // CHECK-NOT: ERROR: Runtime op verification failed
- func.call @reinterpret_cast_fully_dynamic(%alloca_4_dyn, %0, %4, %1) : (memref<?xf32>, index, index, index) -> ()
-
- return
-}
More information about the Mlir-commits mailing list