[Mlir-commits] [mlir] Fixed assume alignment lit test (PR #140446)

Shay Kleiman llvmlistbot at llvm.org
Sun May 18 04:49:36 PDT 2025


https://github.com/shay-kl created https://github.com/llvm/llvm-project/pull/140446

None

>From bddbbe9a382963a94cf1d598b66111b8658a0048 Mon Sep 17 00:00:00 2001
From: Shay Kleiman <shayk at epgd045.me-corp.lan>
Date: Sun, 11 May 2025 16:22:44 +0300
Subject: [PATCH 1/6] Added MemoryEffect to assume_alignment

Assume_alignment has no trait specifying how it interacts with memory.
This causes an issue in OwnershipBasedBufferDeallocation, which
requires all operations that operate on buffers to carry explicit
traits defining how they interact with memory.

To my understanding, the operation is technically pure; however, to
make sure it doesn't get optimized away, it has to have some side
effect. I defined it to have side effects similar to those of CF
AssertOp, as both are assertions. I'm not sure whether this is correct
and would appreciate the opinion of someone more experienced.
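
For illustration, a minimal sketch of the kind of IR this affects (the
function name and memref shape are placeholders, not taken from the
patch):

  func.func @keep_alignment_hint(%arg0: memref<128xi8>) {
    // Without a declared memory effect, OwnershipBasedBufferDeallocation
    // cannot reason about this op; declaring a write effect (as cf.assert
    // does) lets the pass handle it and keeps the hint from being removed
    // as dead code.
    memref.assume_alignment %arg0, 64 : memref<128xi8>
    return
  }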
---
 mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td           | 2 +-
 mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp                   | 5 +++++
 .../OwnershipBasedBufferDeallocation/misc-other.mlir       | 7 +++++++
 3 files changed, 13 insertions(+), 1 deletion(-)

diff --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
index d6d8161d3117b..856b033f401a0 100644
--- a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
+++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
@@ -142,7 +142,7 @@ class AllocLikeOp<string mnemonic,
 // AssumeAlignmentOp
 //===----------------------------------------------------------------------===//
 
-def AssumeAlignmentOp : MemRef_Op<"assume_alignment"> {
+def AssumeAlignmentOp : MemRef_Op<"assume_alignment",[DeclareOpInterfaceMethods<MemoryEffectsOpInterface>]> {
   let summary =
       "assertion that gives alignment information to the input memref";
   let description = [{
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index a0237c18cf2fe..1872a63f93c93 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -526,6 +526,11 @@ LogicalResult AssumeAlignmentOp::verify() {
     return emitOpError("alignment must be power of 2");
   return success();
 }
+void AssumeAlignmentOp::getEffects(
+  SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
+      &effects) {
+effects.emplace_back(MemoryEffects::Write::get());
+}
 
 //===----------------------------------------------------------------------===//
 // CastOp
diff --git a/mlir/test/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation/misc-other.mlir b/mlir/test/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation/misc-other.mlir
index 05e52848ca877..4e0e743fc3ed9 100644
--- a/mlir/test/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation/misc-other.mlir
+++ b/mlir/test/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation/misc-other.mlir
@@ -10,4 +10,11 @@ func.func @func_with_assert(%arg0: index, %arg1: index) {
   %0 = arith.cmpi slt, %arg0, %arg1 : index
   cf.assert %0, "%arg0 must be less than %arg1"
   return
+}
+
+// CHECK-LABEL: func @func_with_assume_alignment(
+//       CHECK: memref.assume_alignment %arg0, 64 : memref<128xi8>
+func.func @func_with_assume_alignment(%arg0: memref<128xi8>) {
+  memref.assume_alignment %arg0, 64 : memref<128xi8>
+  return
 }
\ No newline at end of file

>From e903f6d419772c52bfee7d59f4e7a5f8f0c40ef2 Mon Sep 17 00:00:00 2001
From: Shay Kleiman <shayk at epgd045.me-corp.lan>
Date: Wed, 14 May 2025 15:09:32 +0300
Subject: [PATCH 2/6] Changed AssumeAlignment into a ViewLikeOp

Made AssumeAlignment a ViewLikeOp that returns a new SSA memref equal
to its memref argument, and gave it the NoMemoryEffect trait. This
gives the op a defined memory effect that matches what it does in
practice, and it now interacts well with optimizations, which will not
remove it unless its result is unused.
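
A minimal before/after sketch of the new form (the value names %buf,
%i, %aligned are placeholders):

  // Before: no result, so the op is not connected to later uses.
  memref.assume_alignment %buf, 64 : memref<128xi8>
  %v = memref.load %buf[%i] : memref<128xi8>

  // After: view-like, so uses are threaded through the returned memref
  // and the op stays live as long as its result is used.
  %aligned = memref.assume_alignment %buf, 64 : memref<128xi8>
  %v2 = memref.load %aligned[%i] : memref<128xi8>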
---
 .../mlir/Dialect/MemRef/IR/MemRefOps.td       | 26 +++++++++++-----
 .../Conversion/MemRefToLLVM/MemRefToLLVM.cpp  |  3 +-
 .../GPU/Transforms/EliminateBarriers.cpp      | 10 +------
 mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp      |  5 ----
 .../MemRef/Transforms/EmulateNarrowType.cpp   |  2 +-
 .../Transforms/ExpandStridedMetadata.cpp      | 30 +++++++++++++++++++
 .../expand-then-convert-to-llvm.mlir          |  4 +--
 .../misc-other.mlir                           |  4 +--
 .../Dialect/MemRef/emulate-narrow-type.mlir   | 28 ++++++++---------
 mlir/test/Dialect/MemRef/invalid.mlir         |  4 +--
 mlir/test/Dialect/MemRef/ops.mlir             |  2 +-
 11 files changed, 73 insertions(+), 45 deletions(-)

diff --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
index 856b033f401a0..1d07e4ca7fe84 100644
--- a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
+++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
@@ -142,22 +142,34 @@ class AllocLikeOp<string mnemonic,
 // AssumeAlignmentOp
 //===----------------------------------------------------------------------===//
 
-def AssumeAlignmentOp : MemRef_Op<"assume_alignment",[DeclareOpInterfaceMethods<MemoryEffectsOpInterface>]> {
+def AssumeAlignmentOp : MemRef_Op<"assume_alignment", [
+      NoMemoryEffect, 
+      ViewLikeOpInterface,
+      SameOperandsAndResultType
+    ]> {
   let summary =
       "assertion that gives alignment information to the input memref";
   let description = [{
-    The `assume_alignment` operation takes a memref and an integer of alignment
-    value, and internally annotates the buffer with the given alignment. If
-    the buffer isn't aligned to the given alignment, the behavior is undefined.
+      The `assume_alignment` operation takes a memref and an integer of alignment
+      value. It returns a new SSA value of the same memref type, but associated
+      with the assertion that the underlying buffer is aligned to the given
+      alignment. If the buffer isn't aligned to the given alignment, the 
+      behavior is undefined.
 
-    This operation doesn't affect the semantics of a correct program. It's for
-    optimization only, and the optimization is best-effort.
+      This operation doesn't affect the semantics of a correct program. It's for
+      optimization only, and the optimization is best-effort.
   }];
   let arguments = (ins AnyMemRef:$memref,
                        ConfinedAttr<I32Attr, [IntPositive]>:$alignment);
-  let results = (outs);
+  let results = (outs AnyMemRef:$result);
 
   let assemblyFormat = "$memref `,` $alignment attr-dict `:` type($memref)";
+  let extraClassDeclaration = [{
+    MemRefType getType() { return ::llvm::cast<MemRefType>(getResult().getType()); }
+    
+    Value getViewSource() { return getMemref(); }
+  }];
+
   let hasVerifier = 1;
 }
 
diff --git a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
index c8b2c0bdc6c20..2d9adcf20c209 100644
--- a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
+++ b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
@@ -432,8 +432,7 @@ struct AssumeAlignmentOpLowering
         createIndexAttrConstant(rewriter, loc, getIndexType(), alignment);
     rewriter.create<LLVM::AssumeOp>(loc, trueCond, LLVM::AssumeAlignTag(), ptr,
                                     alignmentConst);
-
-    rewriter.eraseOp(op);
+    rewriter.replaceOp(op, memref);
     return success();
   }
 };
diff --git a/mlir/lib/Dialect/GPU/Transforms/EliminateBarriers.cpp b/mlir/lib/Dialect/GPU/Transforms/EliminateBarriers.cpp
index 84c12c0ba05e5..43e04232ffe7d 100644
--- a/mlir/lib/Dialect/GPU/Transforms/EliminateBarriers.cpp
+++ b/mlir/lib/Dialect/GPU/Transforms/EliminateBarriers.cpp
@@ -44,12 +44,7 @@ using namespace mlir::gpu;
 // The functions below provide interface-like verification, but are too specific
 // to barrier elimination to become interfaces.
 
-/// Implement the MemoryEffectsOpInterface in the suitable way.
-static bool isKnownNoEffectsOpWithoutInterface(Operation *op) {
-  // memref::AssumeAlignment is conceptually pure, but marking it as such would
-  // make DCE immediately remove it.
-  return isa<memref::AssumeAlignmentOp>(op);
-}
+
 
 /// Returns `true` if the op is defines the parallel region that is subject to
 /// barrier synchronization.
@@ -101,9 +96,6 @@ collectEffects(Operation *op,
   if (ignoreBarriers && isa<BarrierOp>(op))
     return true;
 
-  // Skip over ops that we know have no effects.
-  if (isKnownNoEffectsOpWithoutInterface(op))
-    return true;
 
   // Collect effect instances the operation. Note that the implementation of
   // getEffects erases all effect instances that have the type other than the
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index 1872a63f93c93..a0237c18cf2fe 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -526,11 +526,6 @@ LogicalResult AssumeAlignmentOp::verify() {
     return emitOpError("alignment must be power of 2");
   return success();
 }
-void AssumeAlignmentOp::getEffects(
-  SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
-      &effects) {
-effects.emplace_back(MemoryEffects::Write::get());
-}
 
 //===----------------------------------------------------------------------===//
 // CastOp
diff --git a/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp b/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp
index 59cfce28e07e1..d2a032688fb6d 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp
@@ -229,7 +229,7 @@ struct ConvertMemRefAssumeAlignment final
     }
 
     rewriter.replaceOpWithNewOp<memref::AssumeAlignmentOp>(
-        op, adaptor.getMemref(), adaptor.getAlignmentAttr());
+        op, newTy, adaptor.getMemref(), adaptor.getAlignmentAttr());
     return success();
   }
 };
diff --git a/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp b/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp
index e9a80be87a0f7..60838696df4ad 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp
@@ -919,6 +919,34 @@ struct ExtractStridedMetadataOpGetGlobalFolder
   }
 };
 
+/// Pattern to replace `extract_strided_metadata(assume_alignment)`
+///
+/// With
+/// \verbatim
+/// extract_strided_metadata(memref)
+/// \endverbatim
+///
+/// Since `assume_alignment` is a view-like op that does not modify the
+/// underlying buffer, offset, sizes, or strides, extracting strided metadata
+/// from its result is equivalent to extracting it from its source. This
+/// canonicalization removes the unnecessary indirection.
+struct ExtractStridedMetadataOpAssumeAlignmentFolder
+    : public OpRewritePattern<memref::ExtractStridedMetadataOp> {
+public:
+  using OpRewritePattern<memref::ExtractStridedMetadataOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(memref::ExtractStridedMetadataOp op,
+                                PatternRewriter &rewriter) const override {
+    auto assumeAlignmentOp = 
+        op.getSource().getDefiningOp<memref::AssumeAlignmentOp>();
+    if (!assumeAlignmentOp)
+      return failure();
+
+    rewriter.replaceOpWithNewOp<memref::ExtractStridedMetadataOp>(op, assumeAlignmentOp.getViewSource());
+    return success();
+  }
+};
+
 /// Rewrite memref.extract_aligned_pointer_as_index of a ViewLikeOp to the
 /// source of the ViewLikeOp.
 class RewriteExtractAlignedPointerAsIndexOfViewLikeOp
@@ -1185,6 +1213,7 @@ void memref::populateExpandStridedMetadataPatterns(
                ExtractStridedMetadataOpSubviewFolder,
                ExtractStridedMetadataOpCastFolder,
                ExtractStridedMetadataOpMemorySpaceCastFolder,
+               ExtractStridedMetadataOpAssumeAlignmentFolder,
                ExtractStridedMetadataOpExtractStridedMetadataFolder>(
       patterns.getContext());
 }
@@ -1201,6 +1230,7 @@ void memref::populateResolveExtractStridedMetadataPatterns(
                ExtractStridedMetadataOpReinterpretCastFolder,
                ExtractStridedMetadataOpCastFolder,
                ExtractStridedMetadataOpMemorySpaceCastFolder,
+               ExtractStridedMetadataOpAssumeAlignmentFolder,
                ExtractStridedMetadataOpExtractStridedMetadataFolder>(
       patterns.getContext());
 }
diff --git a/mlir/test/Conversion/MemRefToLLVM/expand-then-convert-to-llvm.mlir b/mlir/test/Conversion/MemRefToLLVM/expand-then-convert-to-llvm.mlir
index fe91d26d5a251..8dd7edf3e29b1 100644
--- a/mlir/test/Conversion/MemRefToLLVM/expand-then-convert-to-llvm.mlir
+++ b/mlir/test/Conversion/MemRefToLLVM/expand-then-convert-to-llvm.mlir
@@ -683,7 +683,7 @@ func.func @load_and_assume(
     %arg0: memref<?x?xf32, strided<[?, ?], offset: ?>>,
     %i0: index, %i1: index)
     -> f32 {
-  memref.assume_alignment %arg0, 16 : memref<?x?xf32, strided<[?, ?], offset: ?>>
-  %2 = memref.load %arg0[%i0, %i1] : memref<?x?xf32, strided<[?, ?], offset: ?>>
+  %arg0_align = memref.assume_alignment %arg0, 16 : memref<?x?xf32, strided<[?, ?], offset: ?>>
+  %2 = memref.load %arg0_align[%i0, %i1] : memref<?x?xf32, strided<[?, ?], offset: ?>>
   func.return %2 : f32
 }
diff --git a/mlir/test/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation/misc-other.mlir b/mlir/test/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation/misc-other.mlir
index 4e0e743fc3ed9..c50c25ad8194f 100644
--- a/mlir/test/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation/misc-other.mlir
+++ b/mlir/test/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation/misc-other.mlir
@@ -13,8 +13,8 @@ func.func @func_with_assert(%arg0: index, %arg1: index) {
 }
 
 // CHECK-LABEL: func @func_with_assume_alignment(
-//       CHECK: memref.assume_alignment %arg0, 64 : memref<128xi8>
+//       CHECK: %0 = memref.assume_alignment %arg0, 64 : memref<128xi8>
 func.func @func_with_assume_alignment(%arg0: memref<128xi8>) {
-  memref.assume_alignment %arg0, 64 : memref<128xi8>
+  %0 = memref.assume_alignment %arg0, 64 : memref<128xi8>
   return
 }
\ No newline at end of file
diff --git a/mlir/test/Dialect/MemRef/emulate-narrow-type.mlir b/mlir/test/Dialect/MemRef/emulate-narrow-type.mlir
index 0cb3b7b744476..111a02abcc74c 100644
--- a/mlir/test/Dialect/MemRef/emulate-narrow-type.mlir
+++ b/mlir/test/Dialect/MemRef/emulate-narrow-type.mlir
@@ -63,8 +63,8 @@ func.func @memref_load_i4(%arg0: index) -> i4 {
 
 func.func @memref_load_i4_rank2(%arg0: index, %arg1: index) -> i4 {
     %0 = memref.alloc() : memref<3x125xi4>
-    memref.assume_alignment %0, 64 : memref<3x125xi4>
-    %1 = memref.load %0[%arg0,%arg1] : memref<3x125xi4>
+    %align0 = memref.assume_alignment %0, 64 : memref<3x125xi4>
+    %1 = memref.load %align0[%arg0,%arg1] : memref<3x125xi4>
     return %1 : i4
 }
 //  CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0, s1] -> ((s0 * 125 + s1) floordiv 2)>
@@ -73,9 +73,9 @@ func.func @memref_load_i4_rank2(%arg0: index, %arg1: index) -> i4 {
 // CHECK-SAME:     %[[ARG0:[a-zA-Z0-9_]+]]: index
 // CHECK-SAME:     %[[ARG1:[a-zA-Z0-9_]+]]: index
 //      CHECK:   %[[ALLOC:.+]] = memref.alloc() : memref<188xi8>
-//      CHECK:   memref.assume_alignment %[[ALLOC]], 64 : memref<188xi8>
+//      CHECK:   %[[ASSUME:.+]] = memref.assume_alignment %[[ALLOC]], 64 : memref<188xi8>
 //      CHECK:   %[[INDEX:.+]] = affine.apply #[[MAP0]]()[%[[ARG0]], %[[ARG1]]]
-//      CHECK:   %[[LOAD:.+]] = memref.load %[[ALLOC]][%[[INDEX]]]
+//      CHECK:   %[[LOAD:.+]] = memref.load %[[ASSUME]][%[[INDEX]]]
 //      CHECK:   %[[BITOFFSET:.+]] = affine.apply #[[MAP1]]()[%[[ARG0]], %[[ARG1]]]
 //      CHECK:   %[[CAST:.+]] = arith.index_cast %[[BITOFFSET]] : index to i8
 //      CHECK:   %[[SHIFTRT:.+]] = arith.shrsi %[[LOAD]], %[[CAST]]
@@ -88,9 +88,9 @@ func.func @memref_load_i4_rank2(%arg0: index, %arg1: index) -> i4 {
 // CHECK32-SAME:     %[[ARG0:[a-zA-Z0-9_]+]]: index
 // CHECK32-SAME:     %[[ARG1:[a-zA-Z0-9_]+]]: index
 //      CHECK32:   %[[ALLOC:.+]] = memref.alloc() : memref<47xi32>
-//      CHECK32:   memref.assume_alignment %[[ALLOC]], 64 : memref<47xi32>
+//      CHECK32:   %[[ASSUME:.+]] = memref.assume_alignment %[[ALLOC]], 64 : memref<47xi32>
 //      CHECK32:   %[[INDEX:.+]] = affine.apply #[[MAP0]]()[%[[ARG0]], %[[ARG1]]]
-//      CHECK32:   %[[LOAD:.+]] = memref.load %[[ALLOC]][%[[INDEX]]]
+//      CHECK32:   %[[LOAD:.+]] = memref.load %[[ASSUME]][%[[INDEX]]]
 //      CHECK32:   %[[BITOFFSET:.+]] = affine.apply #[[MAP1]]()[%[[ARG0]], %[[ARG1]]]
 //      CHECK32:   %[[CAST:.+]] = arith.index_cast %[[BITOFFSET]] : index to i32
 //      CHECK32:   %[[SHIFTRT:.+]] = arith.shrsi %[[LOAD]], %[[CAST]]
@@ -350,8 +350,8 @@ func.func @memref_store_i4(%arg0: index, %arg1: i4) -> () {
 
 func.func @memref_store_i4_rank2(%arg0: index, %arg1: index, %arg2: i4) -> () {
     %0 = memref.alloc() : memref<3x125xi4>
-    memref.assume_alignment %0, 64 : memref<3x125xi4>
-    memref.store %arg2, %0[%arg0,%arg1] : memref<3x125xi4>
+    %align0 = memref.assume_alignment %0, 64 : memref<3x125xi4>
+    memref.store %arg2, %align0[%arg0,%arg1] : memref<3x125xi4>
     return
 }
 //  CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0, s1] -> ((s0 * 125 + s1) floordiv 2)>
@@ -359,7 +359,7 @@ func.func @memref_store_i4_rank2(%arg0: index, %arg1: index, %arg2: i4) -> () {
 //      CHECK: func @memref_store_i4_rank2(
 // CHECK-SAME:     %[[ARG0:.+]]: index, %[[ARG1:.+]]: index, %[[ARG2:.+]]: i4
 //  CHECK-DAG:   %[[ALLOC:.+]] = memref.alloc() : memref<188xi8>
-//  CHECK-DAG:   memref.assume_alignment %[[ALLOC]], 64 : memref<188xi8>
+//  CHECK-DAG:   %[[ASSUME:.+]] = memref.assume_alignment %[[ALLOC]], 64 : memref<188xi8>
 //  CHECK-DAG:   %[[EXTUI:.+]] = arith.extui %[[ARG2]] : i4 to i8
 //  CHECK-DAG:   %[[INDEX:.+]] = affine.apply #[[MAP0]]()[%[[ARG0]], %[[ARG1]]]
 //  CHECK-DAG:   %[[BITOFFSET:.+]] = affine.apply #[[MAP1]]()[%[[ARG0]], %[[ARG1]]]
@@ -369,8 +369,8 @@ func.func @memref_store_i4_rank2(%arg0: index, %arg1: index, %arg2: i4) -> () {
 //  CHECK-DAG:   %[[CST_NEG_ONE:.+]] = arith.constant -1 : i8
 //  CHECK-DAG:   %[[MASK:.+]] = arith.xori %[[MASK_SHIFTED]], %[[CST_NEG_ONE]] : i8
 //  CHECK-DAG:   %[[SHIFTED_VAL:.+]] = arith.shli %[[EXTUI]], %[[BITOFFSET_I8]] : i8
-//      CHECK:   %[[CLEAR_RMW:.+]] = memref.atomic_rmw andi %[[MASK]], %[[ALLOC]][%[[INDEX]]] : (i8, memref<188xi8>) -> i8
-//      CHECK:   %[[WRITE_RMW:.+]] = memref.atomic_rmw ori %[[SHIFTED_VAL]], %[[ALLOC]][%[[INDEX]]] : (i8, memref<188xi8>) -> i8
+//      CHECK:   %[[CLEAR_RMW:.+]] = memref.atomic_rmw andi %[[MASK]], %[[ASSUME]][%[[INDEX]]] : (i8, memref<188xi8>) -> i8
+//      CHECK:   %[[WRITE_RMW:.+]] = memref.atomic_rmw ori %[[SHIFTED_VAL]], %[[ASSUME]][%[[INDEX]]] : (i8, memref<188xi8>) -> i8
 //      CHECK:   return
 
 //  CHECK32-DAG: #[[MAP0:.+]] = affine_map<()[s0, s1] -> ((s0 * 125 + s1) floordiv 8)>
@@ -378,7 +378,7 @@ func.func @memref_store_i4_rank2(%arg0: index, %arg1: index, %arg2: i4) -> () {
 //      CHECK32: func @memref_store_i4_rank2(
 // CHECK32-SAME:     %[[ARG0:.+]]: index, %[[ARG1:.+]]: index, %[[ARG2:.+]]: i4
 //  CHECK32-DAG:   %[[ALLOC:.+]] = memref.alloc() : memref<47xi32>
-//  CHECK32-DAG:   memref.assume_alignment %[[ALLOC]], 64 : memref<47xi32>
+//  CHECK32-DAG:   %[[ASSUME:.+]] = memref.assume_alignment %[[ALLOC]], 64 : memref<47xi32>
 //  CHECK32-DAG:   %[[EXTUI:.+]] = arith.extui %[[ARG2]] : i4 to i32
 //  CHECK32-DAG:   %[[INDEX:.+]] = affine.apply #[[MAP0]]()[%[[ARG0]], %[[ARG1]]]
 //  CHECK32-DAG:   %[[BITOFFSET:.+]] = affine.apply #[[MAP1]]()[%[[ARG0]], %[[ARG1]]]
@@ -388,8 +388,8 @@ func.func @memref_store_i4_rank2(%arg0: index, %arg1: index, %arg2: i4) -> () {
 //  CHECK32-DAG:   %[[CST_NEG_ONE:.+]] = arith.constant -1 : i32
 //  CHECK32-DAG:   %[[MASK:.+]] = arith.xori %[[MASK_SHIFTED]], %[[CST_NEG_ONE]] : i32
 //  CHECK32-DAG:   %[[SHIFTED_VAL:.+]] = arith.shli %[[EXTUI]], %[[BITOFFSET_I32]] : i32
-//      CHECK32:   %[[CLEAR_RMW:.+]] = memref.atomic_rmw andi %[[MASK]], %[[ALLOC]][%[[INDEX]]] : (i32, memref<47xi32>) -> i32
-//      CHECK32:   %[[WRITE_RMW:.+]] = memref.atomic_rmw ori %[[SHIFTED_VAL]], %[[ALLOC]][%[[INDEX]]] : (i32, memref<47xi32>) -> i32
+//      CHECK32:   %[[CLEAR_RMW:.+]] = memref.atomic_rmw andi %[[MASK]], %[[ASSUME]][%[[INDEX]]] : (i32, memref<47xi32>) -> i32
+//      CHECK32:   %[[WRITE_RMW:.+]] = memref.atomic_rmw ori %[[SHIFTED_VAL]], %[[ASSUME]][%[[INDEX]]] : (i32, memref<47xi32>) -> i32
 //      CHECK32:   return
 
 // -----
diff --git a/mlir/test/Dialect/MemRef/invalid.mlir b/mlir/test/Dialect/MemRef/invalid.mlir
index 34fc4775924e7..f908efb638446 100644
--- a/mlir/test/Dialect/MemRef/invalid.mlir
+++ b/mlir/test/Dialect/MemRef/invalid.mlir
@@ -878,7 +878,7 @@ func.func @invalid_memref_cast() {
 // alignment is not power of 2.
 func.func @assume_alignment(%0: memref<4x4xf16>) {
   // expected-error at +1 {{alignment must be power of 2}}
-  memref.assume_alignment %0, 12 : memref<4x4xf16>
+  %1 = memref.assume_alignment %0, 12 : memref<4x4xf16>
   return
 }
 
@@ -887,7 +887,7 @@ func.func @assume_alignment(%0: memref<4x4xf16>) {
 // 0 alignment value.
 func.func @assume_alignment(%0: memref<4x4xf16>) {
   // expected-error at +1 {{attribute 'alignment' failed to satisfy constraint: 32-bit signless integer attribute whose value is positive}}
-  memref.assume_alignment %0, 0 : memref<4x4xf16>
+  %1 = memref.assume_alignment %0, 0 : memref<4x4xf16>
   return
 }
 
diff --git a/mlir/test/Dialect/MemRef/ops.mlir b/mlir/test/Dialect/MemRef/ops.mlir
index 7038a6ff744e4..38ee363a7d424 100644
--- a/mlir/test/Dialect/MemRef/ops.mlir
+++ b/mlir/test/Dialect/MemRef/ops.mlir
@@ -284,7 +284,7 @@ func.func @memref_view(%arg0 : index, %arg1 : index, %arg2 : index) {
 // CHECK-SAME: %[[MEMREF:.*]]: memref<4x4xf16>
 func.func @assume_alignment(%0: memref<4x4xf16>) {
   // CHECK: memref.assume_alignment %[[MEMREF]], 16 : memref<4x4xf16>
-  memref.assume_alignment %0, 16 : memref<4x4xf16>
+  %1 = memref.assume_alignment %0, 16 : memref<4x4xf16>
   return
 }
 

>From 42c4e78d30693a7eec5438511a6fbad9305ba0ff Mon Sep 17 00:00:00 2001
From: Shay Kleiman <shayk at epgd045.me-corp.lan>
Date: Thu, 15 May 2025 09:30:03 +0300
Subject: [PATCH 3/6] Made AssumeAlignment Pure

AssumeAlignment is now Pure; its undefined behavior is now deferred
via poison on the result.
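
A short sketch of what this means in IR terms (names and types are
illustrative):

  // If %buf is not actually 16-byte aligned, %aligned is poison, but the
  // op itself has no side effects and does not immediately invoke UB.
  %aligned = memref.assume_alignment %buf, 16 : memref<?xf32>
  // Being Pure, the op is trivially dead and may be erased whenever
  // %aligned has no uses.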
---
 .../mlir/Dialect/MemRef/IR/MemRefOps.td       | 21 ++++++++++++-------
 1 file changed, 14 insertions(+), 7 deletions(-)

diff --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
index 1d07e4ca7fe84..b9a9dbadc6227 100644
--- a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
+++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
@@ -143,21 +143,28 @@ class AllocLikeOp<string mnemonic,
 //===----------------------------------------------------------------------===//
 
 def AssumeAlignmentOp : MemRef_Op<"assume_alignment", [
-      NoMemoryEffect, 
+      Pure, 
       ViewLikeOpInterface,
       SameOperandsAndResultType
     ]> {
   let summary =
       "assertion that gives alignment information to the input memref";
   let description = [{
-      The `assume_alignment` operation takes a memref and an integer of alignment
+      The `assume_alignment` operation takes a memref and an integer alignment
       value. It returns a new SSA value of the same memref type, but associated
       with the assertion that the underlying buffer is aligned to the given
-      alignment. If the buffer isn't aligned to the given alignment, the 
-      behavior is undefined.
-
-      This operation doesn't affect the semantics of a correct program. It's for
-      optimization only, and the optimization is best-effort.
+      alignment. 
+
+      If the buffer isn't actually aligned to the given alignment, this operation
+      itself does not cause undefined behavior. However, subsequent operations
+      that consume the resulting memref and rely on this asserted alignment for
+      correctness (e.g., to avoid hardware traps or to meet ISA requirements for
+      specific instructions) will produce a poison value if the assertion is false.
+
+      This operation doesn't affect the semantics of a program where the
+      alignment assertion holds true. It is intended for optimization purposes,
+      allowing the compiler to generate more efficient code based on the
+      alignment assumption. The optimization is best-effort.
   }];
   let arguments = (ins AnyMemRef:$memref,
                        ConfinedAttr<I32Attr, [IntPositive]>:$alignment);

>From 7de24e02fba311e9041dacfd4db091e7bd76302c Mon Sep 17 00:00:00 2001
From: Shay Kleiman <shayk at epgd045.me-corp.lan>
Date: Thu, 15 May 2025 15:00:28 +0300
Subject: [PATCH 4/6] Added ASM Result Naming to AssumeAlignment

Set the asm result name of the operation's result to "assume_align".
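
With the getAsmResultNames hook, the printed IR gives the result a
readable name instead of a plain numeric SSA id, e.g. (operands here
are illustrative):

  %assume_align = memref.assume_alignment %alloc, 64 : memref<128xi8>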
---
 mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td | 8 ++++----
 mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp         | 4 ++++
 2 files changed, 8 insertions(+), 4 deletions(-)

diff --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
index b9a9dbadc6227..71b92d4ae5f06 100644
--- a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
+++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
@@ -143,6 +143,7 @@ class AllocLikeOp<string mnemonic,
 //===----------------------------------------------------------------------===//
 
 def AssumeAlignmentOp : MemRef_Op<"assume_alignment", [
+      DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
       Pure, 
       ViewLikeOpInterface,
       SameOperandsAndResultType
@@ -150,16 +151,15 @@ def AssumeAlignmentOp : MemRef_Op<"assume_alignment", [
   let summary =
       "assertion that gives alignment information to the input memref";
   let description = [{
-      The `assume_alignment` operation takes a memref and an integer alignment
+      The `assume_alignment` operation takes a memref and an integer of alignment
       value. It returns a new SSA value of the same memref type, but associated
       with the assertion that the underlying buffer is aligned to the given
       alignment. 
 
-      If the buffer isn't actually aligned to the given alignment, this operation
+      If the buffer isn't aligned to the given alignment, this operation
       itself does not cause undefined behavior. However, subsequent operations
       that consume the resulting memref and rely on this asserted alignment for
-      correctness (e.g., to avoid hardware traps or to meet ISA requirements for
-      specific instructions) will produce a poison value if the assertion is false.
+      correctness will produce a poison value if the assertion is false.
 
       This operation doesn't affect the semantics of a program where the
       alignment assertion holds true. It is intended for optimization purposes,
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index a0237c18cf2fe..9f99035d06a2c 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -527,6 +527,10 @@ LogicalResult AssumeAlignmentOp::verify() {
   return success();
 }
 
+void AssumeAlignmentOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
+  setNameFn(getResult(), "assume_align");
+}
+
 //===----------------------------------------------------------------------===//
 // CastOp
 //===----------------------------------------------------------------------===//

>From 1c158c1d00e1563afc36499cd4d8844a6a8b5b7c Mon Sep 17 00:00:00 2001
From: Shay Kleiman <shayk at epgd045.me-corp.lan>
Date: Fri, 16 May 2025 14:27:00 +0300
Subject: [PATCH 5/6] Modified assume_alignment description

Changed "assertion" to "assumption" in the op description and
clarified the wording about poison.
---
 mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td | 16 ++++++----------
 .../Dialect/GPU/Transforms/EliminateBarriers.cpp |  3 ---
 mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp         |  3 ++-
 .../MemRef/Transforms/ExpandStridedMetadata.cpp  |  5 +++--
 4 files changed, 11 insertions(+), 16 deletions(-)

diff --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
index 830cd90cadc25..54ac899f96f06 100644
--- a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
+++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
@@ -144,25 +144,21 @@ class AllocLikeOp<string mnemonic,
 
 def AssumeAlignmentOp : MemRef_Op<"assume_alignment", [
       DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
-      Pure, 
+      Pure,
       ViewLikeOpInterface,
       SameOperandsAndResultType
     ]> {
   let summary =
-      "assertion that gives alignment information to the input memref";
+      "assumption that gives alignment information to the input memref";
   let description = [{
-      The `assume_alignment` operation takes a memref and an integer of alignment
+      The `assume_alignment` operation takes a memref and an integer alignment
       value. It returns a new SSA value of the same memref type, but associated
-      with the assertion that the underlying buffer is aligned to the given
+      with the assumption that the underlying buffer is aligned to the given
       alignment. 
 
-      If the buffer isn't aligned to the given alignment, this operation
-      itself does not cause undefined behavior. However, subsequent operations
-      that consume the resulting memref and rely on this asserted alignment for
-      correctness will produce a poison value if the assertion is false.
-
+      If the buffer isn't aligned to the given alignment, its result is poison.
       This operation doesn't affect the semantics of a program where the
-      alignment assertion holds true. It is intended for optimization purposes,
+      alignment assumption holds true. It is intended for optimization purposes,
       allowing the compiler to generate more efficient code based on the
       alignment assumption. The optimization is best-effort.
   }];
diff --git a/mlir/lib/Dialect/GPU/Transforms/EliminateBarriers.cpp b/mlir/lib/Dialect/GPU/Transforms/EliminateBarriers.cpp
index 43e04232ffe7d..912d4dc99a885 100644
--- a/mlir/lib/Dialect/GPU/Transforms/EliminateBarriers.cpp
+++ b/mlir/lib/Dialect/GPU/Transforms/EliminateBarriers.cpp
@@ -44,8 +44,6 @@ using namespace mlir::gpu;
 // The functions below provide interface-like verification, but are too specific
 // to barrier elimination to become interfaces.
 
-
-
 /// Returns `true` if the op is defines the parallel region that is subject to
 /// barrier synchronization.
 static bool isParallelRegionBoundary(Operation *op) {
@@ -96,7 +94,6 @@ collectEffects(Operation *op,
   if (ignoreBarriers && isa<BarrierOp>(op))
     return true;
 
-
   // Collect effect instances the operation. Note that the implementation of
   // getEffects erases all effect instances that have the type other than the
   // template parameter so we collect them first in a local buffer and then
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index 9f99035d06a2c..82702789c2913 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -527,7 +527,8 @@ LogicalResult AssumeAlignmentOp::verify() {
   return success();
 }
 
-void AssumeAlignmentOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
+void AssumeAlignmentOp::getAsmResultNames(
+    function_ref<void(Value, StringRef)> setNameFn) {
   setNameFn(getResult(), "assume_align");
 }
 
diff --git a/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp b/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp
index 60838696df4ad..cfd529c46a41d 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp
@@ -937,12 +937,13 @@ struct ExtractStridedMetadataOpAssumeAlignmentFolder
 
   LogicalResult matchAndRewrite(memref::ExtractStridedMetadataOp op,
                                 PatternRewriter &rewriter) const override {
-    auto assumeAlignmentOp = 
+    auto assumeAlignmentOp =
         op.getSource().getDefiningOp<memref::AssumeAlignmentOp>();
     if (!assumeAlignmentOp)
       return failure();
 
-    rewriter.replaceOpWithNewOp<memref::ExtractStridedMetadataOp>(op, assumeAlignmentOp.getViewSource());
+    rewriter.replaceOpWithNewOp<memref::ExtractStridedMetadataOp>(
+        op, assumeAlignmentOp.getViewSource());
     return success();
   }
 };

>From e7a781f2777863a1ad7269f16ec0a2243b495e1c Mon Sep 17 00:00:00 2001
From: Shay Kleiman <shay.kleiman at mobileye.com>
Date: Sun, 18 May 2025 14:29:04 +0300
Subject: [PATCH 6/6] Fixed assume_alignment runtime verification test

---
 .../MemRef/assume-alignment-runtime-verification.mlir       | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/mlir/test/Integration/Dialect/MemRef/assume-alignment-runtime-verification.mlir b/mlir/test/Integration/Dialect/MemRef/assume-alignment-runtime-verification.mlir
index 394648d1b8bfa..8f74976c59773 100644
--- a/mlir/test/Integration/Dialect/MemRef/assume-alignment-runtime-verification.mlir
+++ b/mlir/test/Integration/Dialect/MemRef/assume-alignment-runtime-verification.mlir
@@ -10,7 +10,7 @@ func.func @main() {
   // This buffer is properly aligned. There should be no error.
   // CHECK-NOT: ^ memref is not aligned to 8
   %alloc = memref.alloca() : memref<5xf64>
-  memref.assume_alignment %alloc, 8 : memref<5xf64>
+  %0 = memref.assume_alignment %alloc, 8 : memref<5xf64>
 
   // Construct a memref descriptor with a pointer that is not aligned to 4.
   // This cannot be done with just the memref dialect. We have to resort to
@@ -28,10 +28,10 @@ func.func @main() {
   %buffer = builtin.unrealized_conversion_cast %10 : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> to memref<1xf32>
 
   //      CHECK: ERROR: Runtime op verification failed
-  // CHECK-NEXT: "memref.assume_alignment"(%{{.*}}) <{alignment = 4 : i32}> : (memref<1xf32>) -> ()
+  // CHECK-NEXT: %[[ASSUME:.*]] = "memref.assume_alignment"(%{{.*}}) <{alignment = 4 : i32}> : (memref<1xf32>)
   // CHECK-NEXT: ^ memref is not aligned to 4
   // CHECK-NEXT: Location: loc({{.*}})
-  memref.assume_alignment %buffer, 4 : memref<1xf32>
+  %assume = memref.assume_alignment %buffer, 4 : memref<1xf32>
 
   return
 }


