[llvm-branch-commits] [mlir] [mlir][memref] Add runtime verification for `memref.copy` (PR #130437)

Matthias Springer via llvm-branch-commits llvm-branch-commits at lists.llvm.org
Sat Mar 8 12:37:46 PST 2025


https://github.com/matthias-springer created https://github.com/llvm/llvm-project/pull/130437

Implement runtime op verification for `memref.copy`. Only ranked memrefs are verified at the moment.

From 7bb852c420fb718eec9198ec3659fbcd1221ca33 Mon Sep 17 00:00:00 2001
From: Matthias Springer <mspringer at nvidia.com>
Date: Sat, 8 Mar 2025 21:34:30 +0100
Subject: [PATCH] [mlir][memref] Add runtime verification for `memref.copy`

---
 .../Transforms/RuntimeOpVerification.cpp      | 48 +++++++++++++++++++
 .../MemRef/copy-runtime-verification.mlir     | 28 +++++++++++
 2 files changed, 76 insertions(+)
 create mode 100644 mlir/test/Integration/Dialect/MemRef/copy-runtime-verification.mlir

diff --git a/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp b/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp
index ceea27a35a225..c604af249ba2e 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp
@@ -182,6 +182,53 @@ struct CastOpInterface
   }
 };
 
+/// Runtime verification for `memref.copy`: for every dimension that is
+/// dynamic in the source or the target, assert at runtime that both operands
+/// have the same extent. Copies involving unranked memrefs are not verified.
+struct CopyOpInterface
+    : public RuntimeVerifiableOpInterface::ExternalModel<CopyOpInterface,
+                                                         CopyOp> {
+  void generateRuntimeVerification(Operation *op, OpBuilder &builder,
+                                   Location loc) const {
+    auto copyOp = cast<CopyOp>(op);
+    Value source = copyOp.getSource();
+    Value target = copyOp.getTarget();
+    auto srcTy = dyn_cast<MemRefType>(source.getType());
+    auto dstTy = dyn_cast<MemRefType>(target.getType());
+    // TODO: Verification for unranked memrefs is not supported yet.
+    if (!srcTy || !dstTy)
+      return;
+    assert(srcTy.getRank() == dstTy.getRank() && "rank mismatch");
+
+    // Returns the size of dimension `dim` of `memRef` as an index value: a
+    // `memref.dim` op if the size is dynamic, an index constant otherwise.
+    auto materializeDimSize = [&](Value memRef, MemRefType type,
+                                  int64_t dim) -> Value {
+      if (type.isDynamicDim(dim))
+        return builder.create<DimOp>(loc, memRef, dim);
+      return builder.create<arith::ConstantIndexOp>(loc, type.getDimSize(dim));
+    };
+
+    for (int64_t dim = 0, rank = srcTy.getRank(); dim < rank; ++dim) {
+      // Dimensions that are static in both operands were already checked by
+      // the op verifier; nothing to do at runtime.
+      if (!srcTy.isDynamicDim(dim) && !dstTy.isDynamicDim(dim))
+        continue;
+      // Separate statements (rather than call arguments) fix the order in
+      // which the size ops are created: source first, then target.
+      Value srcSize = materializeDimSize(source, srcTy, dim);
+      Value dstSize = materializeDimSize(target, dstTy, dim);
+      Value sizesMatch = builder.create<arith::CmpIOp>(
+          loc, arith::CmpIPredicate::eq, srcSize, dstSize);
+      builder.create<cf::AssertOp>(
+          loc, sizesMatch,
+          RuntimeVerifiableOpInterface::generateErrorMessage(
+              op, "size of " + std::to_string(dim) +
+                      "-th source/target dim does not match"));
+    }
+  }
+};
+
 struct DimOpInterface
     : public RuntimeVerifiableOpInterface::ExternalModel<DimOpInterface,
                                                          DimOp> {
@@ -383,6 +430,7 @@ void mlir::memref::registerRuntimeVerifiableOpInterfaceExternalModels(
     AssumeAlignmentOp::attachInterface<AssumeAlignmentOpInterface>(*ctx);
     AtomicRMWOp::attachInterface<LoadStoreOpInterface<AtomicRMWOp>>(*ctx);
     CastOp::attachInterface<CastOpInterface>(*ctx);
+    CopyOp::attachInterface<CopyOpInterface>(*ctx);
     DimOp::attachInterface<DimOpInterface>(*ctx);
     ExpandShapeOp::attachInterface<ExpandShapeOpInterface>(*ctx);
     GenericAtomicRMWOp::attachInterface<
diff --git a/mlir/test/Integration/Dialect/MemRef/copy-runtime-verification.mlir b/mlir/test/Integration/Dialect/MemRef/copy-runtime-verification.mlir
new file mode 100644
index 0000000000000..95b9db2832cee
--- /dev/null
+++ b/mlir/test/Integration/Dialect/MemRef/copy-runtime-verification.mlir
@@ -0,0 +1,28 @@
+// RUN: mlir-opt %s -generate-runtime-verification \
+// RUN:     -expand-strided-metadata \
+// RUN:     -test-cf-assert \
+// RUN:     -convert-to-llvm | \
+// RUN: mlir-runner -e main -entry-point-result=void \
+// RUN:     -shared-libs=%mlir_runner_utils 2>&1 | \
+// RUN: FileCheck %s
+
+// Keep memref.copy in its own function; otherwise the memref.casts may fold.
+func.func @copy_dynamic(%from: memref<?xf32>, %to: memref<?xf32>) {
+  memref.copy %from, %to : memref<?xf32> to memref<?xf32>
+  return
+}
+
+func.func @main() {
+  %four = memref.alloca() : memref<4xf32>
+  %five = memref.alloca() : memref<5xf32>
+  %dyn_four = memref.cast %four : memref<4xf32> to memref<?xf32>
+  %dyn_five = memref.cast %five : memref<5xf32> to memref<?xf32>
+  // Sizes 4 and 5 differ, so the copy below must fail runtime verification.
+  //      CHECK: ERROR: Runtime op verification failed
+  // CHECK-NEXT: "memref.copy"(%{{.*}}, %{{.*}}) : (memref<?xf32>, memref<?xf32>) -> ()
+  // CHECK-NEXT: ^ size of 0-th source/target dim does not match
+  // CHECK-NEXT: Location: loc({{.*}})
+  call @copy_dynamic(%dyn_four, %dyn_five) : (memref<?xf32>, memref<?xf32>) -> ()
+
+  return
+}



More information about the llvm-branch-commits mailing list