[Mlir-commits] [mlir] [mlir][memref] Add runtime verification for `memref.atomic_rmw` (PR #130414)
Matthias Springer
llvmlistbot at llvm.org
Thu Mar 20 00:31:22 PDT 2025
https://github.com/matthias-springer updated https://github.com/llvm/llvm-project/pull/130414
From 18f917b8c68bf7164b893d274c484d19769ea712 Mon Sep 17 00:00:00 2001
From: Matthias Springer <mspringer at nvidia.com>
Date: Sat, 8 Mar 2025 14:34:54 +0100
Subject: [PATCH] [mlir][memref] Add runtime verification for
`memref.atomic_rmw`
---
.../Transforms/RuntimeOpVerification.cpp | 45 +++++++++++--------
.../atomic-rmw-runtime-verification.mlir | 45 +++++++++++++++++++
.../MemRef/store-runtime-verification.mlir | 45 +++++++++++++++++++
3 files changed, 116 insertions(+), 19 deletions(-)
create mode 100644 mlir/test/Integration/Dialect/MemRef/atomic-rmw-runtime-verification.mlir
create mode 100644 mlir/test/Integration/Dialect/MemRef/store-runtime-verification.mlir
diff --git a/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp b/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp
index 134e8b5efcfdf..be0f4724ea63a 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp
@@ -35,6 +35,26 @@ Value generateInBoundsCheck(OpBuilder &builder, Location loc, Value value,
return inBounds;
}
+/// Generate a runtime check to see if the given indices are in-bounds with
+/// respect to the given ranked memref.
+Value generateIndicesInBoundsCheck(OpBuilder &builder, Location loc,
+ Value memref, ValueRange indices) {
+ auto memrefType = cast<MemRefType>(memref.getType());
+ assert(memrefType.getRank() == static_cast<int64_t>(indices.size()) &&
+ "rank mismatch");
+ Value cond = builder.create<arith::ConstantOp>(
+ loc, builder.getIntegerAttr(builder.getI1Type(), 1));
+
+ auto zero = builder.create<arith::ConstantIndexOp>(loc, 0);
+ for (auto [dim, idx] : llvm::enumerate(indices)) {
+ Value dimOp = builder.createOrFold<memref::DimOp>(loc, memref, dim);
+ Value inBounds = generateInBoundsCheck(builder, loc, idx, zero, dimOp);
+ cond = builder.createOrFold<arith::AndIOp>(loc, cond, inBounds);
+ }
+
+ return cond;
+}
+
struct AssumeAlignmentOpInterface
: public RuntimeVerifiableOpInterface::ExternalModel<
AssumeAlignmentOpInterface, AssumeAlignmentOp> {
@@ -230,26 +250,10 @@ struct LoadStoreOpInterface
void generateRuntimeVerification(Operation *op, OpBuilder &builder,
Location loc) const {
auto loadStoreOp = cast<LoadStoreOp>(op);
-
- auto memref = loadStoreOp.getMemref();
- auto rank = memref.getType().getRank();
- if (rank == 0) {
- return;
- }
- auto indices = loadStoreOp.getIndices();
-
- auto zero = builder.create<arith::ConstantIndexOp>(loc, 0);
- Value assertCond;
- for (auto i : llvm::seq<int64_t>(0, rank)) {
- Value dimOp = builder.createOrFold<memref::DimOp>(loc, memref, i);
- Value inBounds =
- generateInBoundsCheck(builder, loc, indices[i], zero, dimOp);
- assertCond =
- i > 0 ? builder.createOrFold<arith::AndIOp>(loc, assertCond, inBounds)
- : inBounds;
- }
+ Value cond = generateIndicesInBoundsCheck(
+ builder, loc, loadStoreOp.getMemref(), loadStoreOp.getIndices());
builder.create<cf::AssertOp>(
- loc, assertCond,
+ loc, cond,
RuntimeVerifiableOpInterface::generateErrorMessage(
op, "out-of-bounds access"));
}
@@ -421,10 +425,13 @@ void mlir::memref::registerRuntimeVerifiableOpInterfaceExternalModels(
DialectRegistry &registry) {
registry.addExtension(+[](MLIRContext *ctx, memref::MemRefDialect *dialect) {
AssumeAlignmentOp::attachInterface<AssumeAlignmentOpInterface>(*ctx);
+ AtomicRMWOp::attachInterface<LoadStoreOpInterface<AtomicRMWOp>>(*ctx);
CastOp::attachInterface<CastOpInterface>(*ctx);
CopyOp::attachInterface<CopyOpInterface>(*ctx);
DimOp::attachInterface<DimOpInterface>(*ctx);
ExpandShapeOp::attachInterface<ExpandShapeOpInterface>(*ctx);
+ GenericAtomicRMWOp::attachInterface<
+ LoadStoreOpInterface<GenericAtomicRMWOp>>(*ctx);
LoadOp::attachInterface<LoadStoreOpInterface<LoadOp>>(*ctx);
ReinterpretCastOp::attachInterface<ReinterpretCastOpInterface>(*ctx);
StoreOp::attachInterface<LoadStoreOpInterface<StoreOp>>(*ctx);
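For reference, loads, stores, and both atomic RMW ops now go through the shared generateIndicesInBoundsCheck helper. For a 1-D dynamic memref, the guard that the pass emits in front of a memref.atomic_rmw should look roughly like the sketch below. This is an illustration only: the function name is made up, the SSA names are arbitrary, the full assert message is abbreviated, and the two arith.cmpi ops are an assumed expansion of the pre-existing generateInBoundsCheck helper rather than output copied from the pass.

  func.func @atomic_rmw_guard_sketch(%m: memref<?xf32>, %idx: index) {
    %cst = arith.constant 1.0 : f32
    // Guard inserted by -generate-runtime-verification (assumed shape):
    %true = arith.constant true                // seed for the conjunction
    %c0 = arith.constant 0 : index
    %dim = memref.dim %m, %c0 : memref<?xf32>  // dynamic extent of dim 0
    %ge = arith.cmpi sge, %idx, %c0 : index    // 0 <= %idx
    %lt = arith.cmpi slt, %idx, %dim : index   // %idx < %dim
    %ok = arith.andi %ge, %lt : i1
    %cond = arith.andi %true, %ok : i1
    cf.assert %cond, "ERROR: Runtime op verification failed ... out-of-bounds access"
    // The verified op itself:
    memref.atomic_rmw addf %cst, %m[%idx] : (f32, memref<?xf32>) -> f32
    return
  }

The conjunction is seeded with a constant true so that rank-0 memrefs (no indices, loop body never runs) trivially pass the check.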
diff --git a/mlir/test/Integration/Dialect/MemRef/atomic-rmw-runtime-verification.mlir b/mlir/test/Integration/Dialect/MemRef/atomic-rmw-runtime-verification.mlir
new file mode 100644
index 0000000000000..9f70c5ca66f65
--- /dev/null
+++ b/mlir/test/Integration/Dialect/MemRef/atomic-rmw-runtime-verification.mlir
@@ -0,0 +1,45 @@
+// RUN: mlir-opt %s -generate-runtime-verification \
+// RUN: -test-cf-assert \
+// RUN: -expand-strided-metadata \
+// RUN: -lower-affine \
+// RUN: -convert-to-llvm | \
+// RUN: mlir-runner -e main -entry-point-result=void \
+// RUN: -shared-libs=%mlir_runner_utils 2>&1 | \
+// RUN: FileCheck %s
+
+func.func @atomic_rmw_dynamic(%memref: memref<?xf32>, %index: index) {
+ %cst = arith.constant 1.0 : f32
+ memref.atomic_rmw addf %cst, %memref[%index] : (f32, memref<?xf32>) -> f32
+ return
+}
+
+func.func @main() {
+ // Allocate a memref<10xf32>, but disguise it as a memref<5xf32>. This is
+ // necessary because "-test-cf-assert" does not abort the program and we do
+ // not want to segfault when running the test case.
+ %alloc = memref.alloca() : memref<10xf32>
+ %ptr = memref.extract_aligned_pointer_as_index %alloc : memref<10xf32> -> index
+ %ptr_i64 = arith.index_cast %ptr : index to i64
+ %ptr_llvm = llvm.inttoptr %ptr_i64 : i64 to !llvm.ptr
+ %c0 = llvm.mlir.constant(0 : index) : i64
+ %c1 = llvm.mlir.constant(1 : index) : i64
+ %c5 = llvm.mlir.constant(5 : index) : i64
+ %4 = llvm.mlir.poison : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+ %5 = llvm.insertvalue %ptr_llvm, %4[0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+ %6 = llvm.insertvalue %ptr_llvm, %5[1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+ %8 = llvm.insertvalue %c0, %6[2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+ %9 = llvm.insertvalue %c5, %8[3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+ %10 = llvm.insertvalue %c1, %9[4, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+ %buffer = builtin.unrealized_conversion_cast %10 : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> to memref<5xf32>
+ %cast = memref.cast %buffer : memref<5xf32> to memref<?xf32>
+
+ // CHECK: ERROR: Runtime op verification failed
+ // CHECK-NEXT: "memref.atomic_rmw"(%{{.*}}, %{{.*}}, %{{.*}}) <{kind = 0 : i64}> : (f32, memref<?xf32>, index) -> f32
+ // CHECK-NEXT: ^ out-of-bounds access
+ // CHECK-NEXT: Location: loc({{.*}})
+ %c9 = arith.constant 9 : index
+  func.call @atomic_rmw_dynamic(%cast, %c9) : (memref<?xf32>, index) -> ()
+
+ return
+}
+
diff --git a/mlir/test/Integration/Dialect/MemRef/store-runtime-verification.mlir b/mlir/test/Integration/Dialect/MemRef/store-runtime-verification.mlir
new file mode 100644
index 0000000000000..58961ba31d93a
--- /dev/null
+++ b/mlir/test/Integration/Dialect/MemRef/store-runtime-verification.mlir
@@ -0,0 +1,45 @@
+// RUN: mlir-opt %s -generate-runtime-verification \
+// RUN: -test-cf-assert \
+// RUN: -expand-strided-metadata \
+// RUN: -lower-affine \
+// RUN: -convert-to-llvm | \
+// RUN: mlir-runner -e main -entry-point-result=void \
+// RUN: -shared-libs=%mlir_runner_utils 2>&1 | \
+// RUN: FileCheck %s
+
+func.func @store_dynamic(%memref: memref<?xf32>, %index: index) {
+ %cst = arith.constant 1.0 : f32
+ memref.store %cst, %memref[%index] : memref<?xf32>
+ return
+}
+
+func.func @main() {
+ // Allocate a memref<10xf32>, but disguise it as a memref<5xf32>. This is
+ // necessary because "-test-cf-assert" does not abort the program and we do
+ // not want to segfault when running the test case.
+ %alloc = memref.alloca() : memref<10xf32>
+ %ptr = memref.extract_aligned_pointer_as_index %alloc : memref<10xf32> -> index
+ %ptr_i64 = arith.index_cast %ptr : index to i64
+ %ptr_llvm = llvm.inttoptr %ptr_i64 : i64 to !llvm.ptr
+ %c0 = llvm.mlir.constant(0 : index) : i64
+ %c1 = llvm.mlir.constant(1 : index) : i64
+ %c5 = llvm.mlir.constant(5 : index) : i64
+ %4 = llvm.mlir.poison : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+ %5 = llvm.insertvalue %ptr_llvm, %4[0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+ %6 = llvm.insertvalue %ptr_llvm, %5[1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+ %8 = llvm.insertvalue %c0, %6[2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+ %9 = llvm.insertvalue %c5, %8[3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+ %10 = llvm.insertvalue %c1, %9[4, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+ %buffer = builtin.unrealized_conversion_cast %10 : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> to memref<5xf32>
+ %cast = memref.cast %buffer : memref<5xf32> to memref<?xf32>
+
+ // CHECK: ERROR: Runtime op verification failed
+ // CHECK-NEXT: "memref.store"(%{{.*}}, %{{.*}}, %{{.*}}) : (f32, memref<?xf32>, index) -> ()
+ // CHECK-NEXT: ^ out-of-bounds access
+ // CHECK-NEXT: Location: loc({{.*}})
+ %c9 = arith.constant 9 : index
+ func.call @store_dynamic(%cast, %c9) : (memref<?xf32>, index) -> ()
+
+ return
+}
+
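One note on coverage: the registration above also attaches the interface to memref.generic_atomic_rmw, which neither test file exercises. A hypothetical test function for that op, which the same generated check would flag when %index is out of bounds, might look like the sketch below (illustration only, not part of the patch; the function name is made up):

  func.func @generic_atomic_rmw_dynamic(%memref: memref<?xf32>, %index: index) {
    // Same access pattern as the memref.atomic_rmw test above, but through
    // the region-based op.
    %0 = memref.generic_atomic_rmw %memref[%index] : memref<?xf32> {
    ^bb0(%current: f32):
      %c1 = arith.constant 1.0 : f32
      %new = arith.addf %current, %c1 : f32
      memref.atomic_yield %new : f32
    }
    return
  }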