[llvm-branch-commits] [mlir] [mlir][memref] Add runtime verification for `memref.assume_alignment` (PR #130412)
via llvm-branch-commits
llvm-branch-commits at lists.llvm.org
Sat Mar 8 04:47:36 PST 2025
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir
Author: Matthias Springer (matthias-springer)
<details>
<summary>Changes</summary>
Implement runtime verification for `memref.assume_alignment`.
---
Full diff: https://github.com/llvm/llvm-project/pull/130412.diff
2 Files Affected:
- (modified) mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp (+24-1)
- (added) mlir/test/Integration/Dialect/MemRef/assume-alignment-runtime-verification.mlir (+37)
``````````diff
diff --git a/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp b/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp
index f825d7d9d42c2..c8e7325d7ac89 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp
@@ -23,7 +23,7 @@ using namespace mlir;
namespace mlir {
namespace memref {
namespace {
-/// Generate a runtime check for lb <= value < ub.
+/// Generate a runtime check for lb <= value < ub.
Value generateInBoundsCheck(OpBuilder &builder, Location loc, Value value,
Value lb, Value ub) {
Value inBounds1 = builder.createOrFold<arith::CmpIOp>(
@@ -35,6 +35,28 @@ Value generateInBoundsCheck(OpBuilder &builder, Location loc, Value value,
return inBounds;
}
+/// Runtime check: the aligned pointer must be a multiple of the alignment.
+struct AssumeAlignmentOpInterface
+    : public RuntimeVerifiableOpInterface::ExternalModel<
+          AssumeAlignmentOpInterface, AssumeAlignmentOp> {
+  void generateRuntimeVerification(Operation *op, OpBuilder &builder,
+                                   Location loc) const {
+    auto alignOp = cast<AssumeAlignmentOp>(op);
+    auto alignment = alignOp.getAlignment();
+    Value alignedPtr = builder.create<ExtractAlignedPointerAsIndexOp>(
+        loc, alignOp.getMemref());
+    Value divisor = builder.create<arith::ConstantIndexOp>(loc, alignment);
+    Value remainder = builder.create<arith::RemUIOp>(loc, alignedPtr, divisor);
+    Value zero = builder.create<arith::ConstantIndexOp>(loc, 0);
+    Value isAligned = builder.create<arith::CmpIOp>(
+        loc, arith::CmpIPredicate::eq, remainder, zero);
+    std::string msg = "memref is not aligned to " + std::to_string(alignment);
+    builder.create<cf::AssertOp>(
+        loc, isAligned,
+        RuntimeVerifiableOpInterface::generateErrorMessage(op, msg));
+  }
+};
+
struct CastOpInterface
: public RuntimeVerifiableOpInterface::ExternalModel<CastOpInterface,
CastOp> {
@@ -354,6 +376,7 @@ struct ExpandShapeOpInterface
void mlir::memref::registerRuntimeVerifiableOpInterfaceExternalModels(
DialectRegistry &registry) {
registry.addExtension(+[](MLIRContext *ctx, memref::MemRefDialect *dialect) {
+ AssumeAlignmentOp::attachInterface<AssumeAlignmentOpInterface>(*ctx);
CastOp::attachInterface<CastOpInterface>(*ctx);
DimOp::attachInterface<DimOpInterface>(*ctx);
ExpandShapeOp::attachInterface<ExpandShapeOpInterface>(*ctx);
diff --git a/mlir/test/Integration/Dialect/MemRef/assume-alignment-runtime-verification.mlir b/mlir/test/Integration/Dialect/MemRef/assume-alignment-runtime-verification.mlir
new file mode 100644
index 0000000000000..394648d1b8bfa
--- /dev/null
+++ b/mlir/test/Integration/Dialect/MemRef/assume-alignment-runtime-verification.mlir
@@ -0,0 +1,37 @@
+// RUN: mlir-opt %s -generate-runtime-verification \
+// RUN: -expand-strided-metadata \
+// RUN: -test-cf-assert \
+// RUN: -convert-to-llvm | \
+// RUN: mlir-runner -e main -entry-point-result=void \
+// RUN: -shared-libs=%mlir_runner_utils 2>&1 | \
+// RUN: FileCheck %s
+
+func.func @main() {
+  // A stack buffer of f64 elements is expected to be 8-byte aligned, so the
+  // alignment assertion on it should stay silent.
+  // CHECK-NOT: ^ memref is not aligned to 8
+  %aligned = memref.alloca() : memref<5xf64>
+  memref.assume_alignment %aligned, 8 : memref<5xf64>
+
+  // Hand-build a descriptor whose pointers both equal 3 (not a multiple of 4).
+  // The memref dialect alone cannot express this, so LLVM ops are used.
+  %zero = llvm.mlir.constant(0 : index) : i64
+  %one = llvm.mlir.constant(1 : index) : i64
+  %three = llvm.mlir.constant(3 : index) : i64
+  %bad_ptr = llvm.inttoptr %three : i64 to !llvm.ptr
+  %desc0 = llvm.mlir.poison : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+  %desc1 = llvm.insertvalue %bad_ptr, %desc0[0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+  %desc2 = llvm.insertvalue %bad_ptr, %desc1[1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+  %desc3 = llvm.insertvalue %zero, %desc2[2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+  %desc4 = llvm.insertvalue %one, %desc3[3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+  %desc5 = llvm.insertvalue %one, %desc4[4, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+  %unaligned = builtin.unrealized_conversion_cast %desc5 : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> to memref<1xf32>
+
+  // CHECK: ERROR: Runtime op verification failed
+  // CHECK-NEXT: "memref.assume_alignment"(%{{.*}}) <{alignment = 4 : i32}> : (memref<1xf32>) -> ()
+  // CHECK-NEXT: ^ memref is not aligned to 4
+  // CHECK-NEXT: Location: loc({{.*}})
+  memref.assume_alignment %unaligned, 4 : memref<1xf32>
+
+  return
+}
``````````
</details>
https://github.com/llvm/llvm-project/pull/130412
More information about the llvm-branch-commits
mailing list