[Mlir-commits] [mlir] 1b455df - [mlir][memref] Add runtime verification for `memref.copy` (#130437)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Tue Mar 11 05:20:59 PDT 2025
Author: Matthias Springer
Date: 2025-03-11T13:20:48+01:00
New Revision: 1b455df780ed1d578b63f969c636fe78b2eb6014
URL: https://github.com/llvm/llvm-project/commit/1b455df780ed1d578b63f969c636fe78b2eb6014
DIFF: https://github.com/llvm/llvm-project/commit/1b455df780ed1d578b63f969c636fe78b2eb6014.diff
LOG: [mlir][memref] Add runtime verification for `memref.copy` (#130437)
Implement runtime op verification for `memref.copy`. Only ranked memrefs
are verified at the moment.
Added:
mlir/test/Integration/Dialect/MemRef/copy-runtime-verification.mlir
Modified:
mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp b/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp
index f93ae0a7a298f..53a618d787333 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp
@@ -128,6 +128,50 @@ struct CastOpInterface
}
};
+/// Runtime verification for `memref.copy`. For every dimension that is
+/// dynamic in the source or the target operand, emits a `cf.assert` checking
+/// that the runtime sizes of source and target agree. Unranked operands are
+/// currently skipped.
+struct CopyOpInterface
+    : public RuntimeVerifiableOpInterface::ExternalModel<CopyOpInterface,
+                                                         CopyOp> {
+  void generateRuntimeVerification(Operation *op, OpBuilder &builder,
+                                   Location loc) const {
+    auto copyOp = cast<CopyOp>(op);
+    Value source = copyOp.getSource();
+    Value target = copyOp.getTarget();
+    auto srcMemRefType = dyn_cast<MemRefType>(source.getType());
+    auto dstMemRefType = dyn_cast<MemRefType>(target.getType());
+
+    // TODO: Verification for unranked memrefs is not supported yet.
+    if (!srcMemRefType || !dstMemRefType)
+      return;
+
+    // Materializes the size of dimension `dim` of `memRef`: a `memref.dim`
+    // op for a dynamic dimension, an index constant for a static one.
+    auto materializeDimSize = [&](Value memRef, MemRefType type,
+                                  int64_t dim) -> Value {
+      if (!type.isDynamicDim(dim))
+        return builder
+            .create<arith::ConstantIndexOp>(loc, type.getDimSize(dim))
+            .getResult();
+      return builder.create<DimOp>(loc, memRef, dim).getResult();
+    };
+
+    const int64_t rank = srcMemRefType.getRank();
+    assert(rank == dstMemRefType.getRank() && "rank mismatch");
+    for (int64_t dim = 0; dim < rank; ++dim) {
+      // Dimensions that are static in both operands are already checked by
+      // the op verifier, so no runtime check is emitted for them.
+      const bool staticInBoth = !srcMemRefType.isDynamicDim(dim) &&
+                                !dstMemRefType.isDynamicDim(dim);
+      if (staticInBoth)
+        continue;
+
+      Value srcSize = materializeDimSize(source, srcMemRefType, dim);
+      Value dstSize = materializeDimSize(target, dstMemRefType, dim);
+      Value sizesMatch = builder.create<arith::CmpIOp>(
+          loc, arith::CmpIPredicate::eq, srcSize, dstSize);
+      const std::string message = "size of " + std::to_string(dim) +
+                                  "-th source/target dim does not match";
+      builder.create<cf::AssertOp>(
+          loc, sizesMatch,
+          RuntimeVerifiableOpInterface::generateErrorMessage(op, message));
+    }
+  }
+};
+
/// Verifies that the indices on load/store ops are in-bounds of the memref's
/// index space: 0 <= index#i < dim#i
template <typename LoadStoreOp>
@@ -335,6 +379,7 @@ void mlir::memref::registerRuntimeVerifiableOpInterfaceExternalModels(
DialectRegistry &registry) {
registry.addExtension(+[](MLIRContext *ctx, memref::MemRefDialect *dialect) {
CastOp::attachInterface<CastOpInterface>(*ctx);
+ CopyOp::attachInterface<CopyOpInterface>(*ctx);
ExpandShapeOp::attachInterface<ExpandShapeOpInterface>(*ctx);
LoadOp::attachInterface<LoadStoreOpInterface<LoadOp>>(*ctx);
ReinterpretCastOp::attachInterface<ReinterpretCastOpInterface>(*ctx);
diff --git a/mlir/test/Integration/Dialect/MemRef/copy-runtime-verification.mlir b/mlir/test/Integration/Dialect/MemRef/copy-runtime-verification.mlir
new file mode 100644
index 0000000000000..95b9db2832cee
--- /dev/null
+++ b/mlir/test/Integration/Dialect/MemRef/copy-runtime-verification.mlir
@@ -0,0 +1,28 @@
+// RUN: mlir-opt %s -generate-runtime-verification \
+// RUN: -expand-strided-metadata \
+// RUN: -test-cf-assert \
+// RUN: -convert-to-llvm | \
+// RUN: mlir-runner -e main -entry-point-result=void \
+// RUN: -shared-libs=%mlir_runner_utils 2>&1 | \
+// RUN: FileCheck %s
+
+// Keep memref.copy in its own function; if it sat directly in @main, the
+// memref.cast ops feeding it might fold away.
+func.func @memcpy_helper(%src: memref<?xf32>, %dest: memref<?xf32>) {
+  "memref.copy"(%src, %dest) : (memref<?xf32>, memref<?xf32>) -> ()
+  func.return
+}
+
+func.func @main() {
+  // Two buffers whose static sizes (4 vs. 5) deliberately differ.
+  %src_static = memref.alloca() : memref<4xf32>
+  %dst_static = memref.alloca() : memref<5xf32>
+  // Cast to a dynamic size so the mismatch can only be caught at runtime.
+  %src = memref.cast %src_static : memref<4xf32> to memref<?xf32>
+  %dst = memref.cast %dst_static : memref<5xf32> to memref<?xf32>
+
+  // CHECK: ERROR: Runtime op verification failed
+  // CHECK-NEXT: "memref.copy"(%{{.*}}, %{{.*}}) : (memref<?xf32>, memref<?xf32>) -> ()
+  // CHECK-NEXT: ^ size of 0-th source/target dim does not match
+  // CHECK-NEXT: Location: loc({{.*}})
+  func.call @memcpy_helper(%src, %dst) : (memref<?xf32>, memref<?xf32>) -> ()
+
+  func.return
+}
More information about the Mlir-commits
mailing list