[Mlir-commits] [mlir] [mlir][memref] Add runtime verification for `memref.dim` (PR #130410)
Matthias Springer
llvmlistbot at llvm.org
Sat Mar 8 04:05:37 PST 2025
https://github.com/matthias-springer created https://github.com/llvm/llvm-project/pull/130410
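Add runtime verification for `memref.dim`: check that the dimension index is in bounds (`0 <= index < rank`).
The `GenerateRuntimeVerification` pass now collects all runtime-verifiable ops before generating any checks, so that ops created by a check (e.g., the `memref.dim` ops emitted by the `memref.load`/`memref.store` bounds checks) are not themselves verified.
Also simplify the pass pipelines of the memref runtime verification integration tests.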
>From 68839721b4d5d7ef8fc456767521a169a5fd9307 Mon Sep 17 00:00:00 2001
From: Matthias Springer <mspringer at nvidia.com>
Date: Sat, 8 Mar 2025 13:03:49 +0100
Subject: [PATCH] [mlir][memref] Add runtime verification for `memref.dim`
---
.../Transforms/RuntimeOpVerification.cpp | 45 ++++++++++++++-----
.../GenerateRuntimeVerification.cpp | 12 ++++-
.../MemRef/cast-runtime-verification.mlir | 7 ++-
.../MemRef/dim-runtime-verification.mlir | 20 +++++++++
.../MemRef/load-runtime-verification.mlir | 8 ++--
...reinterpret-cast-runtime-verification.mlir | 8 ++--
.../MemRef/subview-runtime-verification.mlir | 7 +--
7 files changed, 74 insertions(+), 33 deletions(-)
create mode 100644 mlir/test/Integration/Dialect/MemRef/dim-runtime-verification.mlir
diff --git a/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp b/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp
index f93ae0a7a298f..f825d7d9d42c2 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp
@@ -23,6 +23,18 @@ using namespace mlir;
namespace mlir {
namespace memref {
namespace {
+/// Generate a runtime check for lb <= value < ub and return the i1 result.
+/// Both comparisons are signed, so a negative `value` is reported as out of
+/// bounds when `lb` is 0.
+Value generateInBoundsCheck(OpBuilder &builder, Location loc, Value value,
+ Value lb, Value ub) {
+ Value geLowerBound = builder.createOrFold<arith::CmpIOp>(
+ loc, arith::CmpIPredicate::sge, value, lb);
+ Value ltUpperBound = builder.createOrFold<arith::CmpIOp>(
+ loc, arith::CmpIPredicate::slt, value, ub);
+ return builder.createOrFold<arith::AndIOp>(loc, geLowerBound, ltUpperBound);
+}
+
struct CastOpInterface
: public RuntimeVerifiableOpInterface::ExternalModel<CastOpInterface,
CastOp> {
@@ -128,6 +140,21 @@ struct CastOpInterface
}
};
+/// Verifies that the dimension index of a `memref.dim` op is in bounds of the
+/// source memref's rank: 0 <= index < rank.
+struct DimOpInterface
+ : public RuntimeVerifiableOpInterface::ExternalModel<DimOpInterface,
+ DimOp> {
+ void generateRuntimeVerification(Operation *op, OpBuilder &builder,
+ Location loc) const {
+ auto dimOp = cast<DimOp>(op);
+ // Use createOrFold (as the load/store check does for memref.dim) so that
+ // the rank of a ranked memref can fold to a constant instead of always
+ // emitting a memref.rank op.
+ Value rank = builder.createOrFold<RankOp>(loc, dimOp.getSource());
+ Value zero = builder.create<arith::ConstantIndexOp>(loc, 0);
+ builder.create<cf::AssertOp>(
+ loc, generateInBoundsCheck(builder, loc, dimOp.getIndex(), zero, rank),
+ RuntimeVerifiableOpInterface::generateErrorMessage(
+ op, "index is out of bounds"));
+ }
+};
+
/// Verifies that the indices on load/store ops are in-bounds of the memref's
/// index space: 0 <= index#i < dim#i
template <typename LoadStoreOp>
@@ -148,19 +175,12 @@ struct LoadStoreOpInterface
auto zero = builder.create<arith::ConstantIndexOp>(loc, 0);
Value assertCond;
for (auto i : llvm::seq<int64_t>(0, rank)) {
- auto index = indices[i];
-
- auto dimOp = builder.createOrFold<memref::DimOp>(loc, memref, i);
-
- auto geLow = builder.createOrFold<arith::CmpIOp>(
- loc, arith::CmpIPredicate::sge, index, zero);
- auto ltHigh = builder.createOrFold<arith::CmpIOp>(
- loc, arith::CmpIPredicate::slt, index, dimOp);
- auto andOp = builder.createOrFold<arith::AndIOp>(loc, geLow, ltHigh);
-
+ Value dimOp = builder.createOrFold<memref::DimOp>(loc, memref, i);
+ Value inBounds =
+ generateInBoundsCheck(builder, loc, indices[i], zero, dimOp);
assertCond =
- i > 0 ? builder.createOrFold<arith::AndIOp>(loc, assertCond, andOp)
- : andOp;
+ i > 0 ? builder.createOrFold<arith::AndIOp>(loc, assertCond, inBounds)
+ : inBounds;
}
builder.create<cf::AssertOp>(
loc, assertCond,
@@ -335,6 +355,7 @@ void mlir::memref::registerRuntimeVerifiableOpInterfaceExternalModels(
DialectRegistry ®istry) {
registry.addExtension(+[](MLIRContext *ctx, memref::MemRefDialect *dialect) {
CastOp::attachInterface<CastOpInterface>(*ctx);
+ DimOp::attachInterface<DimOpInterface>(*ctx);
ExpandShapeOp::attachInterface<ExpandShapeOpInterface>(*ctx);
LoadOp::attachInterface<LoadStoreOpInterface<LoadOp>>(*ctx);
ReinterpretCastOp::attachInterface<ReinterpretCastOpInterface>(*ctx);
diff --git a/mlir/lib/Transforms/GenerateRuntimeVerification.cpp b/mlir/lib/Transforms/GenerateRuntimeVerification.cpp
index 62db9ce1316ae..a40bc2b3272fc 100644
--- a/mlir/lib/Transforms/GenerateRuntimeVerification.cpp
+++ b/mlir/lib/Transforms/GenerateRuntimeVerification.cpp
@@ -28,11 +28,19 @@ struct GenerateRuntimeVerificationPass
} // namespace
void GenerateRuntimeVerificationPass::runOnOperation() {
+ // Implementations of RuntimeVerifiableOpInterface may create ops that are
+ // themselves runtime-verifiable (e.g., the memref.dim ops emitted by the
+ // memref.load/store bounds checks). We don't want to generate verification
+ // for IR that performs verification, so gather all runtime-verifiable ops
+ // before modifying the IR.
+ SmallVector<RuntimeVerifiableOpInterface> ops;
getOperation()->walk([&](RuntimeVerifiableOpInterface verifiableOp) {
- OpBuilder builder(getOperation()->getContext());
+ ops.push_back(verifiableOp);
+ });
+
+ // A single builder suffices; its insertion point is reset for each op.
+ OpBuilder builder(getOperation()->getContext());
+ for (RuntimeVerifiableOpInterface verifiableOp : ops) {
builder.setInsertionPoint(verifiableOp);
verifiableOp.generateRuntimeVerification(builder, verifiableOp.getLoc());
- });
+ }
}
std::unique_ptr<Pass> mlir::createGenerateRuntimeVerificationPass() {
diff --git a/mlir/test/Integration/Dialect/MemRef/cast-runtime-verification.mlir b/mlir/test/Integration/Dialect/MemRef/cast-runtime-verification.mlir
index b101a875154ff..8b6308e9c1939 100644
--- a/mlir/test/Integration/Dialect/MemRef/cast-runtime-verification.mlir
+++ b/mlir/test/Integration/Dialect/MemRef/cast-runtime-verification.mlir
@@ -1,8 +1,7 @@
-// RUN: mlir-opt %s -generate-runtime-verification -finalize-memref-to-llvm \
+// RUN: mlir-opt %s -generate-runtime-verification \
// RUN: -test-cf-assert \
-// RUN: -convert-func-to-llvm \
-// RUN: -convert-arith-to-llvm \
-// RUN: -reconcile-unrealized-casts | \
+// RUN: -expand-strided-metadata \
+// RUN: -convert-to-llvm | \
// RUN: mlir-runner -e main -entry-point-result=void \
// RUN: -shared-libs=%mlir_runner_utils 2>&1 | \
// RUN: FileCheck %s
diff --git a/mlir/test/Integration/Dialect/MemRef/dim-runtime-verification.mlir b/mlir/test/Integration/Dialect/MemRef/dim-runtime-verification.mlir
new file mode 100644
index 0000000000000..2e3f271743c93
--- /dev/null
+++ b/mlir/test/Integration/Dialect/MemRef/dim-runtime-verification.mlir
@@ -0,0 +1,20 @@
+// RUN: mlir-opt %s -generate-runtime-verification \
+// RUN: -expand-strided-metadata \
+// RUN: -test-cf-assert \
+// RUN: -convert-to-llvm | \
+// RUN: mlir-runner -e main -entry-point-result=void \
+// RUN: -shared-libs=%mlir_runner_utils 2>&1 | \
+// RUN: FileCheck %s
+
+func.func @main() {
+ %c4 = arith.constant 4 : index
+ %alloca = memref.alloca() : memref<1xf32>
+
+ // memref<1xf32> has rank 1, so dimension index 4 is out of bounds and the
+ // generated runtime check is expected to report an error.
+ // CHECK: ERROR: Runtime op verification failed
+ // CHECK-NEXT: "memref.dim"(%{{.*}}, %{{.*}}) : (memref<1xf32>, index) -> index
+ // CHECK-NEXT: ^ index is out of bounds
+ // CHECK-NEXT: Location: loc({{.*}})
+ %dim = memref.dim %alloca, %c4 : memref<1xf32>
+
+ return
+}
diff --git a/mlir/test/Integration/Dialect/MemRef/load-runtime-verification.mlir b/mlir/test/Integration/Dialect/MemRef/load-runtime-verification.mlir
index d6c5d6da0041e..b87e5bdf0970c 100644
--- a/mlir/test/Integration/Dialect/MemRef/load-runtime-verification.mlir
+++ b/mlir/test/Integration/Dialect/MemRef/load-runtime-verification.mlir
@@ -1,10 +1,8 @@
// RUN: mlir-opt %s -generate-runtime-verification \
-// RUN: -expand-strided-metadata \
-// RUN: -finalize-memref-to-llvm \
// RUN: -test-cf-assert \
-// RUN: -convert-func-to-llvm \
-// RUN: -convert-arith-to-llvm \
-// RUN: -reconcile-unrealized-casts | \
+// RUN: -expand-strided-metadata \
+// RUN: -lower-affine \
+// RUN: -convert-to-llvm | \
// RUN: mlir-runner -e main -entry-point-result=void \
// RUN: -shared-libs=%mlir_runner_utils 2>&1 | \
// RUN: FileCheck %s
diff --git a/mlir/test/Integration/Dialect/MemRef/reinterpret-cast-runtime-verification.mlir b/mlir/test/Integration/Dialect/MemRef/reinterpret-cast-runtime-verification.mlir
index 9fea48bdfc07d..601a53f4b5cd9 100644
--- a/mlir/test/Integration/Dialect/MemRef/reinterpret-cast-runtime-verification.mlir
+++ b/mlir/test/Integration/Dialect/MemRef/reinterpret-cast-runtime-verification.mlir
@@ -1,10 +1,8 @@
// RUN: mlir-opt %s -generate-runtime-verification \
-// RUN: -lower-affine \
-// RUN: -finalize-memref-to-llvm \
// RUN: -test-cf-assert \
-// RUN: -convert-func-to-llvm \
-// RUN: -convert-arith-to-llvm \
-// RUN: -reconcile-unrealized-casts | \
+// RUN: -expand-strided-metadata \
+// RUN: -lower-affine \
+// RUN: -convert-to-llvm | \
// RUN: mlir-runner -e main -entry-point-result=void \
// RUN: -shared-libs=%mlir_runner_utils 2>&1 | \
// RUN: FileCheck %s
diff --git a/mlir/test/Integration/Dialect/MemRef/subview-runtime-verification.mlir b/mlir/test/Integration/Dialect/MemRef/subview-runtime-verification.mlir
index 66474e9c4ae37..3cac37a082c30 100644
--- a/mlir/test/Integration/Dialect/MemRef/subview-runtime-verification.mlir
+++ b/mlir/test/Integration/Dialect/MemRef/subview-runtime-verification.mlir
@@ -1,11 +1,8 @@
// RUN: mlir-opt %s -generate-runtime-verification \
+// RUN: -test-cf-assert \
// RUN: -expand-strided-metadata \
// RUN: -lower-affine \
-// RUN: -finalize-memref-to-llvm \
-// RUN: -test-cf-assert \
-// RUN: -convert-func-to-llvm \
-// RUN: -convert-arith-to-llvm \
-// RUN: -reconcile-unrealized-casts | \
+// RUN: -convert-to-llvm | \
// RUN: mlir-runner -e main -entry-point-result=void \
// RUN: -shared-libs=%mlir_runner_utils 2>&1 | \
// RUN: FileCheck %s
More information about the Mlir-commits
mailing list