[llvm-branch-commits] [mlir] [mlir][memref] Add runtime verification for `memref.assume_alignment` (PR #130412)

via llvm-branch-commits llvm-branch-commits at lists.llvm.org
Sat Mar 8 04:47:36 PST 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

Author: Matthias Springer (matthias-springer)

<details>
<summary>Changes</summary>

Implement runtime verification for `memref.assume_alignment`.


---
Full diff: https://github.com/llvm/llvm-project/pull/130412.diff


2 Files Affected:

- (modified) mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp (+24-1) 
- (added) mlir/test/Integration/Dialect/MemRef/assume-alignment-runtime-verification.mlir (+37) 


``````````diff
diff --git a/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp b/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp
index f825d7d9d42c2..c8e7325d7ac89 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp
@@ -23,7 +23,7 @@ using namespace mlir;
 namespace mlir {
 namespace memref {
 namespace {
-/// Generate a runtime check for lb <= value < ub. 
+/// Generate a runtime check for lb <= value < ub.
 Value generateInBoundsCheck(OpBuilder &builder, Location loc, Value value,
                             Value lb, Value ub) {
   Value inBounds1 = builder.createOrFold<arith::CmpIOp>(
@@ -35,6 +35,28 @@ Value generateInBoundsCheck(OpBuilder &builder, Location loc, Value value,
   return inBounds;
 }
 
+struct AssumeAlignmentOpInterface // Verifies memref.assume_alignment.
+    : public RuntimeVerifiableOpInterface::ExternalModel<
+          AssumeAlignmentOpInterface, AssumeAlignmentOp> {
+  void generateRuntimeVerification(Operation *op, OpBuilder &builder,
+                                   Location loc) const {
+    auto assumeOp = cast<AssumeAlignmentOp>(op);
+    Value ptr = builder.create<ExtractAlignedPointerAsIndexOp>(
+        loc, assumeOp.getMemref()); // Aligned pointer as an index.
+    Value rest = builder.create<arith::RemUIOp>(
+        loc, ptr, // rest = ptr mod alignment (unsigned remainder).
+        builder.create<arith::ConstantIndexOp>(loc, assumeOp.getAlignment()));
+    Value isAligned = builder.create<arith::CmpIOp>(
+        loc, arith::CmpIPredicate::eq, rest, // True iff rest == 0.
+        builder.create<arith::ConstantIndexOp>(loc, 0));
+    builder.create<cf::AssertOp>(
+        loc, isAligned, // Assert that the alignment check holds.
+        RuntimeVerifiableOpInterface::generateErrorMessage(
+            op, "memref is not aligned to " +
+                    std::to_string(assumeOp.getAlignment())));
+  }
+};
+
 struct CastOpInterface
     : public RuntimeVerifiableOpInterface::ExternalModel<CastOpInterface,
                                                          CastOp> {
@@ -354,6 +376,7 @@ struct ExpandShapeOpInterface
 void mlir::memref::registerRuntimeVerifiableOpInterfaceExternalModels(
     DialectRegistry &registry) {
   registry.addExtension(+[](MLIRContext *ctx, memref::MemRefDialect *dialect) {
+    AssumeAlignmentOp::attachInterface<AssumeAlignmentOpInterface>(*ctx);
     CastOp::attachInterface<CastOpInterface>(*ctx);
     DimOp::attachInterface<DimOpInterface>(*ctx);
     ExpandShapeOp::attachInterface<ExpandShapeOpInterface>(*ctx);
diff --git a/mlir/test/Integration/Dialect/MemRef/assume-alignment-runtime-verification.mlir b/mlir/test/Integration/Dialect/MemRef/assume-alignment-runtime-verification.mlir
new file mode 100644
index 0000000000000..394648d1b8bfa
--- /dev/null
+++ b/mlir/test/Integration/Dialect/MemRef/assume-alignment-runtime-verification.mlir
@@ -0,0 +1,37 @@
+// RUN: mlir-opt %s -generate-runtime-verification \
+// RUN:     -expand-strided-metadata \
+// RUN:     -test-cf-assert \
+// RUN:     -convert-to-llvm | \
+// RUN: mlir-runner -e main -entry-point-result=void \
+// RUN:     -shared-libs=%mlir_runner_utils 2>&1 | \
+// RUN: FileCheck %s
+
+func.func @main() {
+  // This buffer is properly aligned. There should be no error.
+  // CHECK-NOT: ^ memref is not aligned to 8
+  %alloc = memref.alloca() : memref<5xf64>
+  memref.assume_alignment %alloc, 8 : memref<5xf64>
+
+  // Construct a memref descriptor whose two pointer fields both hold the
+  // address 3, which is not a multiple of 4. This cannot be done with just
+  // the memref dialect, so we have to resort to the LLVM dialect.
+  %c0 = llvm.mlir.constant(0 : index) : i64
+  %c1 = llvm.mlir.constant(1 : index) : i64
+  %c3 = llvm.mlir.constant(3 : index) : i64
+  %unaligned_ptr = llvm.inttoptr %c3 : i64 to !llvm.ptr
+  %4 = llvm.mlir.poison : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+  %5 = llvm.insertvalue %unaligned_ptr, %4[0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+  %6 = llvm.insertvalue %unaligned_ptr, %5[1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+  %8 = llvm.insertvalue %c0, %6[2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+  %9 = llvm.insertvalue %c1, %8[3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+  %10 = llvm.insertvalue %c1, %9[4, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+  %buffer = builtin.unrealized_conversion_cast %10 : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> to memref<1xf32>
+
+  //      CHECK: ERROR: Runtime op verification failed
+  // CHECK-NEXT: "memref.assume_alignment"(%{{.*}}) <{alignment = 4 : i32}> : (memref<1xf32>) -> ()
+  // CHECK-NEXT: ^ memref is not aligned to 4
+  // CHECK-NEXT: Location: loc({{.*}})
+  memref.assume_alignment %buffer, 4 : memref<1xf32>
+
+  return
+}

``````````

</details>


https://github.com/llvm/llvm-project/pull/130412


More information about the llvm-branch-commits mailing list