[Mlir-commits] [mlir] [mlir][memref] Add runtime verification for `memref.dim` (PR #130410)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Sat Mar 8 04:06:07 PST 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

Author: Matthias Springer (matthias-springer)

<details>
<summary>Changes</summary>

Add runtime verification for `memref.dim`: check that the index is in bounds.

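Also simplify the pass pipelines of all memref runtime verification integration tests, and make `-generate-runtime-verification` collect all runtime-verifiable ops before generating any checks, so that no checks are generated for the IR that performs the verification.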


---
Full diff: https://github.com/llvm/llvm-project/pull/130410.diff


7 Files Affected:

- (modified) mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp (+33-12) 
- (modified) mlir/lib/Transforms/GenerateRuntimeVerification.cpp (+10-2) 
- (modified) mlir/test/Integration/Dialect/MemRef/cast-runtime-verification.mlir (+3-4) 
- (added) mlir/test/Integration/Dialect/MemRef/dim-runtime-verification.mlir (+20) 
- (modified) mlir/test/Integration/Dialect/MemRef/load-runtime-verification.mlir (+3-5) 
- (modified) mlir/test/Integration/Dialect/MemRef/reinterpret-cast-runtime-verification.mlir (+3-5) 
- (modified) mlir/test/Integration/Dialect/MemRef/subview-runtime-verification.mlir (+2-5) 


``````````diff
diff --git a/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp b/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp
index f93ae0a7a298f..f825d7d9d42c2 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp
@@ -23,6 +23,18 @@ using namespace mlir;
 namespace mlir {
 namespace memref {
 namespace {
+/// Return an i1 value that is true iff lb <= value < ub (signed comparison).
+Value generateInBoundsCheck(OpBuilder &builder, Location loc, Value value,
+                            Value lb, Value ub) {
+  Value inBounds1 = builder.createOrFold<arith::CmpIOp>(
+      loc, arith::CmpIPredicate::sge, value, lb);
+  Value inBounds2 = builder.createOrFold<arith::CmpIOp>(
+      loc, arith::CmpIPredicate::slt, value, ub);
+  Value inBounds =
+      builder.createOrFold<arith::AndIOp>(loc, inBounds1, inBounds2);
+  return inBounds;
+}
+
 struct CastOpInterface
     : public RuntimeVerifiableOpInterface::ExternalModel<CastOpInterface,
                                                          CastOp> {
@@ -128,6 +140,21 @@ struct CastOpInterface
   }
 };
 
+struct DimOpInterface
+    : public RuntimeVerifiableOpInterface::ExternalModel<DimOpInterface,
+                                                         DimOp> {
+  void generateRuntimeVerification(Operation *op, OpBuilder &builder,
+                                   Location loc) const {
+    auto dimOp = cast<DimOp>(op);
+    Value rank = builder.createOrFold<RankOp>(loc, dimOp.getSource());
+    Value zero = builder.create<arith::ConstantIndexOp>(loc, 0);
+    builder.create<cf::AssertOp>(
+        loc, generateInBoundsCheck(builder, loc, dimOp.getIndex(), zero, rank),
+        RuntimeVerifiableOpInterface::generateErrorMessage(
+            op, "index is out of bounds"));
+  }
+};
+
 /// Verifies that the indices on load/store ops are in-bounds of the memref's
 /// index space: 0 <= index#i < dim#i
 template <typename LoadStoreOp>
@@ -148,19 +175,12 @@ struct LoadStoreOpInterface
     auto zero = builder.create<arith::ConstantIndexOp>(loc, 0);
     Value assertCond;
     for (auto i : llvm::seq<int64_t>(0, rank)) {
-      auto index = indices[i];
-
-      auto dimOp = builder.createOrFold<memref::DimOp>(loc, memref, i);
-
-      auto geLow = builder.createOrFold<arith::CmpIOp>(
-          loc, arith::CmpIPredicate::sge, index, zero);
-      auto ltHigh = builder.createOrFold<arith::CmpIOp>(
-          loc, arith::CmpIPredicate::slt, index, dimOp);
-      auto andOp = builder.createOrFold<arith::AndIOp>(loc, geLow, ltHigh);
-
+      Value dimOp = builder.createOrFold<memref::DimOp>(loc, memref, i);
+      Value inBounds =
+          generateInBoundsCheck(builder, loc, indices[i], zero, dimOp);
       assertCond =
-          i > 0 ? builder.createOrFold<arith::AndIOp>(loc, assertCond, andOp)
-                : andOp;
+          i > 0 ? builder.createOrFold<arith::AndIOp>(loc, assertCond, inBounds)
+                : inBounds;
     }
     builder.create<cf::AssertOp>(
         loc, assertCond,
@@ -335,6 +355,7 @@ void mlir::memref::registerRuntimeVerifiableOpInterfaceExternalModels(
     DialectRegistry &registry) {
   registry.addExtension(+[](MLIRContext *ctx, memref::MemRefDialect *dialect) {
     CastOp::attachInterface<CastOpInterface>(*ctx);
+    DimOp::attachInterface<DimOpInterface>(*ctx);
     ExpandShapeOp::attachInterface<ExpandShapeOpInterface>(*ctx);
     LoadOp::attachInterface<LoadStoreOpInterface<LoadOp>>(*ctx);
     ReinterpretCastOp::attachInterface<ReinterpretCastOpInterface>(*ctx);
diff --git a/mlir/lib/Transforms/GenerateRuntimeVerification.cpp b/mlir/lib/Transforms/GenerateRuntimeVerification.cpp
index 62db9ce1316ae..a40bc2b3272fc 100644
--- a/mlir/lib/Transforms/GenerateRuntimeVerification.cpp
+++ b/mlir/lib/Transforms/GenerateRuntimeVerification.cpp
@@ -28,11 +28,19 @@ struct GenerateRuntimeVerificationPass
 } // namespace
 
 void GenerateRuntimeVerificationPass::runOnOperation() {
+  // The implementation of the RuntimeVerifiableOpInterface may create ops that
+  // can be verified. We don't want to generate verification for IR that
+  // performs verification, so gather all runtime-verifiable ops first.
+  SmallVector<RuntimeVerifiableOpInterface> ops;
   getOperation()->walk([&](RuntimeVerifiableOpInterface verifiableOp) {
-    OpBuilder builder(getOperation()->getContext());
+    ops.push_back(verifiableOp);
+  });
+
+  OpBuilder builder(getOperation()->getContext());
+  for (RuntimeVerifiableOpInterface verifiableOp : ops) {
     builder.setInsertionPoint(verifiableOp);
     verifiableOp.generateRuntimeVerification(builder, verifiableOp.getLoc());
-  });
+  }
 }
 
 std::unique_ptr<Pass> mlir::createGenerateRuntimeVerificationPass() {
diff --git a/mlir/test/Integration/Dialect/MemRef/cast-runtime-verification.mlir b/mlir/test/Integration/Dialect/MemRef/cast-runtime-verification.mlir
index b101a875154ff..8b6308e9c1939 100644
--- a/mlir/test/Integration/Dialect/MemRef/cast-runtime-verification.mlir
+++ b/mlir/test/Integration/Dialect/MemRef/cast-runtime-verification.mlir
@@ -1,8 +1,7 @@
-// RUN: mlir-opt %s -generate-runtime-verification -finalize-memref-to-llvm \
+// RUN: mlir-opt %s -generate-runtime-verification \
 // RUN:     -test-cf-assert \
-// RUN:     -convert-func-to-llvm \
-// RUN:     -convert-arith-to-llvm \
-// RUN:     -reconcile-unrealized-casts | \
+// RUN:     -expand-strided-metadata \
+// RUN:     -convert-to-llvm | \
 // RUN: mlir-runner -e main -entry-point-result=void \
 // RUN:     -shared-libs=%mlir_runner_utils 2>&1 | \
 // RUN: FileCheck %s
diff --git a/mlir/test/Integration/Dialect/MemRef/dim-runtime-verification.mlir b/mlir/test/Integration/Dialect/MemRef/dim-runtime-verification.mlir
new file mode 100644
index 0000000000000..2e3f271743c93
--- /dev/null
+++ b/mlir/test/Integration/Dialect/MemRef/dim-runtime-verification.mlir
@@ -0,0 +1,20 @@
+// RUN: mlir-opt %s -generate-runtime-verification \
+// RUN:     -expand-strided-metadata \
+// RUN:     -test-cf-assert \
+// RUN:     -convert-to-llvm | \
+// RUN: mlir-runner -e main -entry-point-result=void \
+// RUN:     -shared-libs=%mlir_runner_utils 2>&1 | \
+// RUN: FileCheck %s
+
+func.func @main() {
+  %c4 = arith.constant 4 : index
+  %alloca = memref.alloca() : memref<1xf32>
+
+  //      CHECK: ERROR: Runtime op verification failed
+  // CHECK-NEXT: "memref.dim"(%{{.*}}, %{{.*}}) : (memref<1xf32>, index) -> index
+  // CHECK-NEXT: ^ index is out of bounds
+  // CHECK-NEXT: Location: loc({{.*}})
+  %dim = memref.dim %alloca, %c4 : memref<1xf32>
+
+  return
+}
diff --git a/mlir/test/Integration/Dialect/MemRef/load-runtime-verification.mlir b/mlir/test/Integration/Dialect/MemRef/load-runtime-verification.mlir
index d6c5d6da0041e..b87e5bdf0970c 100644
--- a/mlir/test/Integration/Dialect/MemRef/load-runtime-verification.mlir
+++ b/mlir/test/Integration/Dialect/MemRef/load-runtime-verification.mlir
@@ -1,10 +1,8 @@
 // RUN: mlir-opt %s -generate-runtime-verification \
-// RUN:     -expand-strided-metadata \
-// RUN:     -finalize-memref-to-llvm \
 // RUN:     -test-cf-assert \
-// RUN:     -convert-func-to-llvm \
-// RUN:     -convert-arith-to-llvm \
-// RUN:     -reconcile-unrealized-casts | \
+// RUN:     -expand-strided-metadata \
+// RUN:     -lower-affine \
+// RUN:     -convert-to-llvm | \
 // RUN: mlir-runner -e main -entry-point-result=void \
 // RUN:     -shared-libs=%mlir_runner_utils 2>&1 | \
 // RUN: FileCheck %s
diff --git a/mlir/test/Integration/Dialect/MemRef/reinterpret-cast-runtime-verification.mlir b/mlir/test/Integration/Dialect/MemRef/reinterpret-cast-runtime-verification.mlir
index 9fea48bdfc07d..601a53f4b5cd9 100644
--- a/mlir/test/Integration/Dialect/MemRef/reinterpret-cast-runtime-verification.mlir
+++ b/mlir/test/Integration/Dialect/MemRef/reinterpret-cast-runtime-verification.mlir
@@ -1,10 +1,8 @@
 // RUN: mlir-opt %s -generate-runtime-verification \
-// RUN:     -lower-affine \
-// RUN:     -finalize-memref-to-llvm \
 // RUN:     -test-cf-assert \
-// RUN:     -convert-func-to-llvm \
-// RUN:     -convert-arith-to-llvm \
-// RUN:     -reconcile-unrealized-casts | \
+// RUN:     -expand-strided-metadata \
+// RUN:     -lower-affine \
+// RUN:     -convert-to-llvm | \
 // RUN: mlir-runner -e main -entry-point-result=void \
 // RUN:     -shared-libs=%mlir_runner_utils 2>&1 | \
 // RUN: FileCheck %s
diff --git a/mlir/test/Integration/Dialect/MemRef/subview-runtime-verification.mlir b/mlir/test/Integration/Dialect/MemRef/subview-runtime-verification.mlir
index 66474e9c4ae37..3cac37a082c30 100644
--- a/mlir/test/Integration/Dialect/MemRef/subview-runtime-verification.mlir
+++ b/mlir/test/Integration/Dialect/MemRef/subview-runtime-verification.mlir
@@ -1,11 +1,8 @@
 // RUN: mlir-opt %s -generate-runtime-verification \
+// RUN:     -test-cf-assert \
 // RUN:     -expand-strided-metadata \
 // RUN:     -lower-affine \
-// RUN:     -finalize-memref-to-llvm \
-// RUN:     -test-cf-assert \
-// RUN:     -convert-func-to-llvm \
-// RUN:     -convert-arith-to-llvm \
-// RUN:     -reconcile-unrealized-casts | \
+// RUN:     -convert-to-llvm | \
 // RUN: mlir-runner -e main -entry-point-result=void \
 // RUN:     -shared-libs=%mlir_runner_utils 2>&1 | \
 // RUN: FileCheck %s

``````````

</details>


https://github.com/llvm/llvm-project/pull/130410


More information about the Mlir-commits mailing list