[Mlir-commits] [mlir] [mlir][memref] Add runtime verification for `memref.assume_alignment` (PR #130412)

Matthias Springer llvmlistbot at llvm.org
Tue Mar 18 02:20:03 PDT 2025


https://github.com/matthias-springer updated https://github.com/llvm/llvm-project/pull/130412

>From d6b573e147b1c06782c1bce651975a3b27a8eeb9 Mon Sep 17 00:00:00 2001
From: Matthias Springer <mspringer at nvidia.com>
Date: Sat, 8 Mar 2025 13:45:45 +0100
Subject: [PATCH] [mlir][memref] Add runtime verification for
 `memref.assume_alignment`

---
 .../Transforms/RuntimeOpVerification.cpp      | 23 ++++++++++++
 ...assume-alignment-runtime-verification.mlir | 37 +++++++++++++++++++
 2 files changed, 60 insertions(+)
 create mode 100644 mlir/test/Integration/Dialect/MemRef/assume-alignment-runtime-verification.mlir

diff --git a/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp b/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp
index 3fd561de3b5e6..134e8b5efcfdf 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp
@@ -35,6 +35,28 @@ Value generateInBoundsCheck(OpBuilder &builder, Location loc, Value value,
   return inBounds;
 }
 
+/// Runtime check: the memref's aligned pointer must be a multiple of the
+/// `alignment` attribute; otherwise a `cf.assert` reports the violation.
+struct AssumeAlignmentOpInterface
+    : public RuntimeVerifiableOpInterface::ExternalModel<
+          AssumeAlignmentOpInterface, AssumeAlignmentOp> {
+  void generateRuntimeVerification(Operation *op, OpBuilder &builder,
+                                   Location loc) const {
+    auto alignOp = cast<AssumeAlignmentOp>(op);
+    const auto alignment = alignOp.getAlignment();
+    Value alignedPtr = builder.create<ExtractAlignedPointerAsIndexOp>(
+        loc, alignOp.getMemref());
+    Value divisor = builder.create<arith::ConstantIndexOp>(loc, alignment);
+    Value remainder = builder.create<arith::RemUIOp>(loc, alignedPtr, divisor);
+    Value isAligned = builder.create<arith::CmpIOp>(
+        loc, arith::CmpIPredicate::eq, remainder,
+        builder.create<arith::ConstantIndexOp>(loc, 0));
+    std::string errorMsg = RuntimeVerifiableOpInterface::generateErrorMessage(
+        op, "memref is not aligned to " + std::to_string(alignment));
+    builder.create<cf::AssertOp>(loc, isAligned, errorMsg);
+  }
+};
+
 struct CastOpInterface
     : public RuntimeVerifiableOpInterface::ExternalModel<CastOpInterface,
                                                          CastOp> {
@@ -398,6 +420,7 @@ struct ExpandShapeOpInterface
 void mlir::memref::registerRuntimeVerifiableOpInterfaceExternalModels(
     DialectRegistry &registry) {
   registry.addExtension(+[](MLIRContext *ctx, memref::MemRefDialect *dialect) {
+    AssumeAlignmentOp::attachInterface<AssumeAlignmentOpInterface>(*ctx);
     CastOp::attachInterface<CastOpInterface>(*ctx);
     CopyOp::attachInterface<CopyOpInterface>(*ctx);
     DimOp::attachInterface<DimOpInterface>(*ctx);
diff --git a/mlir/test/Integration/Dialect/MemRef/assume-alignment-runtime-verification.mlir b/mlir/test/Integration/Dialect/MemRef/assume-alignment-runtime-verification.mlir
new file mode 100644
index 0000000000000..394648d1b8bfa
--- /dev/null
+++ b/mlir/test/Integration/Dialect/MemRef/assume-alignment-runtime-verification.mlir
@@ -0,0 +1,37 @@
+// RUN: mlir-opt %s -generate-runtime-verification \
+// RUN:     -expand-strided-metadata \
+// RUN:     -test-cf-assert \
+// RUN:     -convert-to-llvm | \
+// RUN: mlir-runner -e main -entry-point-result=void \
+// RUN:     -shared-libs=%mlir_runner_utils 2>&1 | \
+// RUN: FileCheck %s
+
+func.func @main() {
+  // An f64 alloca is expected to be 8-byte aligned: no error should appear.
+  // CHECK-NOT: ^ memref is not aligned to 8
+  %alloc = memref.alloca() : memref<5xf64>
+  memref.assume_alignment %alloc, 8 : memref<5xf64>
+
+  // Build a descriptor whose two pointer fields both hold address 3, which is
+  // not 4-byte aligned. The memref dialect alone cannot express this, so the
+  // descriptor is assembled with LLVM dialect ops and cast back to a memref.
+  %c0 = llvm.mlir.constant(0 : index) : i64
+  %c1 = llvm.mlir.constant(1 : index) : i64
+  %c3 = llvm.mlir.constant(3 : index) : i64
+  %unaligned_ptr = llvm.inttoptr %c3 : i64 to !llvm.ptr
+  %4 = llvm.mlir.poison : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+  %5 = llvm.insertvalue %unaligned_ptr, %4[0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+  %6 = llvm.insertvalue %unaligned_ptr, %5[1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+  %8 = llvm.insertvalue %c0, %6[2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+  %9 = llvm.insertvalue %c1, %8[3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+  %10 = llvm.insertvalue %c1, %9[4, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+  %buffer = builtin.unrealized_conversion_cast %10 : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> to memref<1xf32>
+
+  //      CHECK: ERROR: Runtime op verification failed
+  // CHECK-NEXT: "memref.assume_alignment"(%{{.*}}) <{alignment = 4 : i32}> : (memref<1xf32>) -> ()
+  // CHECK-NEXT: ^ memref is not aligned to 4
+  // CHECK-NEXT: Location: loc({{.*}})
+  memref.assume_alignment %buffer, 4 : memref<1xf32>
+
+  return
+}



More information about the Mlir-commits mailing list