[Mlir-commits] [mlir] [mlir][hoisting] Support memref.assume_alignment in linalg hoisting (PR #144843)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Jun 18 23:03:01 PDT 2025


https://github.com/xiangzh1 created https://github.com/llvm/llvm-project/pull/144843

The recent updates to AssumeAlignmentOp affect the linalg hoisting optimization.
We found a regression in "hoist load/store out of loop".
The following issue lists more detail:

related issue : [144825](https://github.com/llvm/llvm-project/issues/144825)

This patch fixes the problem: because assume_alignment only marks the memref's alignment,
linalg hoisting should check its memref operand, not the op itself.


>From 3b387e7baae4a098ee8bd90ecf1797967c2c43bf Mon Sep 17 00:00:00 2001
From: Zhang Xiang <xiang.zhang at iluvatar.com>
Date: Wed, 18 Jun 2025 17:10:15 +0800
Subject: [PATCH 1/2] [mlir][NFC] Pre-commit test for linalg hoisting

---
 mlir/test/Dialect/Linalg/hoisting.mlir | 51 ++++++++++++++++++++++++++
 1 file changed, 51 insertions(+)

diff --git a/mlir/test/Dialect/Linalg/hoisting.mlir b/mlir/test/Dialect/Linalg/hoisting.mlir
index 318edca73cce1..67dfe7a2af98b 100644
--- a/mlir/test/Dialect/Linalg/hoisting.mlir
+++ b/mlir/test/Dialect/Linalg/hoisting.mlir
@@ -802,3 +802,54 @@ module attributes {transform.with_named_sequence} {
     transform.yield
   }
 }
+
+// -----
+
+// Test hoisting of a vector.transfer_read/transfer_write pair with the same
+// location when that location is marked with memref.assume_alignment.
+
+// CHECK-LABEL:  func.func @hoist_vector_transfer_read_write() {
+// CHECK:          %c0 = arith.constant 0 : index
+// CHECK-NEXT:     %c256 = arith.constant 256 : index
+// CHECK-NEXT:     %c4096 = arith.constant 4096 : index
+// CHECK-NEXT:     %cst = arith.constant 0.000000e+00 : f16
+// CHECK-NEXT:     %alloc = memref.alloc() : memref<4096x4096xf16>
+// CHECK-NEXT:     %alloc_0 = memref.alloc() : memref<4096x4096xf16>
+// CHECK-NEXT:     %assume_align = memref.assume_alignment %alloc, 64 : memref<4096x4096xf16>
+// CHECK-NEXT:     scf.for %arg0 = %c256 to %c4096 step %c256 {
+// CHECK-NEXT:       %0 = vector.transfer_read %assume_align[%c0, %c0], %cst {in_bounds = [true, true]} : memref<4096x4096xf16>, vector<16x16xf16>
+// CHECK-NEXT:       %1 = vector.transfer_read %alloc_0[%arg0, %arg0], %cst {in_bounds = [true, true]} : memref<4096x4096xf16>, vector<16x16xf16>
+// CHECK-NEXT:       %2 = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %1, %1, %0 : vector<16x16xf16>, vector<16x16xf16> into vector<16x16xf16>
+// CHECK-NEXT:       vector.transfer_write %2, %assume_align[%c0, %c0] {in_bounds = [true, true]} : vector<16x16xf16>, memref<4096x4096xf16>
+// CHECK-NEXT:     }
+// CHECK-NEXT:     return
+// CHECK-NEXT:   }
+
+func.func @hoist_vector_transfer_read_write() {
+  %c0 = arith.constant 0 : index
+  %c64 = arith.constant 64 : index
+  %c256 = arith.constant 256 : index
+  %c4096 = arith.constant 4096 : index
+  %cst_0 = arith.constant 0.000000e+00 : f16
+  %m0 = memref.alloc() : memref<4096x4096xf16>
+  %m1 = memref.alloc() : memref<4096x4096xf16>
+  %assume_align_0 = memref.assume_alignment %m0, 64 : memref<4096x4096xf16>
+  %assume_align_1 = memref.assume_alignment %m1, 64 : memref<4096x4096xf16>
+  scf.for %arg0 = %c256 to %c4096 step %c256 {
+    %1 = vector.transfer_read %assume_align_0[%c0, %c0], %cst_0 {in_bounds = [true, true]} : memref<4096x4096xf16>, vector<16x16xf16>
+    %2 = vector.transfer_read %m1[%arg0, %arg0], %cst_0 {in_bounds = [true, true]} : memref<4096x4096xf16>, vector<16x16xf16>
+    %3 = vector.contract {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %2, %2, %1 : vector<16x16xf16>, vector<16x16xf16> into vector<16x16xf16>
+    vector.transfer_write %3, %assume_align_0[%c0, %c0] {in_bounds = [true, true]} : vector<16x16xf16>, memref<4096x4096xf16>
+  }
+  return
+}
+
+module attributes {transform.with_named_sequence} {
+        transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["func.func"]} in %arg1
+      : (!transform.any_op) -> !transform.any_op
+    transform.structured.hoist_redundant_vector_transfers %0
+      : (!transform.any_op) -> !transform.any_op
+    transform.yield
+  }
+}

>From 706c1daea66acd1b899aac37a6716a1bf9255c3c Mon Sep 17 00:00:00 2001
From: Zhang Xiang <xiang.zhang at iluvatar.com>
Date: Wed, 18 Jun 2025 14:05:04 +0800
Subject: [PATCH 2/2] [mlir][hoisting] Support memref.assume_alignment in
 linalg hoisting

All ViewLike operations are excluded by the hoisting optimization. But
assume_alignment just marks the memref's alignment, so we should check its
memref operand instead of the op itself.
---
 .../Dialect/Linalg/Transforms/Hoisting.cpp    | 26 +++++++++++++++----
 mlir/test/Dialect/Linalg/hoisting.mlir        | 11 ++++----
 2 files changed, 27 insertions(+), 10 deletions(-)

diff --git a/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp b/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp
index 707b63ff9335b..b949b06631484 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp
@@ -199,6 +199,24 @@ static bool noAliasingUseInLoop(vector::TransferReadOp transferRead,
   return true;
 }
 
+/// Returns true when hoisting must be skipped because the base of the transfer
+/// read (`source0`) or transfer write (`source1`) is defined by a view-like op.
+/// If both bases come from the same memref.assume_alignment op, the op that
+/// defines its memref operand is checked instead of the op itself.
+static bool skipViewLike(Operation *source0, Operation *source1) {
+  auto assumeAlignOp = dyn_cast_or_null<memref::AssumeAlignmentOp>(source0);
+  if (assumeAlignOp && source0 == source1) {
+    // assume_alignment only marks the alignment of its memref; look through it.
+    Operation *sourceOp = assumeAlignOp.getMemref().getDefiningOp();
+    return isa_and_nonnull<ViewLikeOpInterface>(sourceOp);
+  }
+
+  // isa_and_nonnull handles a null defining op, so no extra null check is
+  // needed here.
+  return isa_and_nonnull<ViewLikeOpInterface>(source0) ||
+         isa_and_nonnull<ViewLikeOpInterface>(source1);
+}
+
 void mlir::linalg::hoistRedundantVectorTransfers(Operation *root,
                                                  bool verifyNonZeroTrip) {
   bool changed = true;
@@ -312,12 +330,10 @@ void mlir::linalg::hoistRedundantVectorTransfers(Operation *root,
           transferRead.getPermutationMap() != transferWrite.getPermutationMap())
         return WalkResult::advance();
 
-      auto *source = transferRead.getBase().getDefiningOp();
-      if (source && isa_and_nonnull<ViewLikeOpInterface>(source))
-        return WalkResult::advance();
+      auto *source0 = transferRead.getBase().getDefiningOp();
+      auto *source1 = transferWrite.getBase().getDefiningOp();
 
-      source = transferWrite.getBase().getDefiningOp();
-      if (source && isa_and_nonnull<ViewLikeOpInterface>(source))
+      if (skipViewLike(source0, source1))
         return WalkResult::advance();
 
       // TODO: may want to memoize this information for performance but it
diff --git a/mlir/test/Dialect/Linalg/hoisting.mlir b/mlir/test/Dialect/Linalg/hoisting.mlir
index 67dfe7a2af98b..c58074e40c5f4 100644
--- a/mlir/test/Dialect/Linalg/hoisting.mlir
+++ b/mlir/test/Dialect/Linalg/hoisting.mlir
@@ -816,12 +816,13 @@ module attributes {transform.with_named_sequence} {
 // CHECK-NEXT:     %alloc = memref.alloc() : memref<4096x4096xf16>
 // CHECK-NEXT:     %alloc_0 = memref.alloc() : memref<4096x4096xf16>
 // CHECK-NEXT:     %assume_align = memref.assume_alignment %alloc, 64 : memref<4096x4096xf16>
-// CHECK-NEXT:     scf.for %arg0 = %c256 to %c4096 step %c256 {
-// CHECK-NEXT:       %0 = vector.transfer_read %assume_align[%c0, %c0], %cst {in_bounds = [true, true]} : memref<4096x4096xf16>, vector<16x16xf16>
-// CHECK-NEXT:       %1 = vector.transfer_read %alloc_0[%arg0, %arg0], %cst {in_bounds = [true, true]} : memref<4096x4096xf16>, vector<16x16xf16>
-// CHECK-NEXT:       %2 = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %1, %1, %0 : vector<16x16xf16>, vector<16x16xf16> into vector<16x16xf16>
-// CHECK-NEXT:       vector.transfer_write %2, %assume_align[%c0, %c0] {in_bounds = [true, true]} : vector<16x16xf16>, memref<4096x4096xf16>
+// CHECK-NEXT:     %0 = vector.transfer_read %assume_align[%c0, %c0], %cst {in_bounds = [true, true]} : memref<4096x4096xf16>, vector<16x16xf16>
+// CHECK-NEXT:     %1 = scf.for %arg0 = %c256 to %c4096 step %c256 iter_args(%arg1 = %0) -> (vector<16x16xf16>) {
+// CHECK-NEXT:       %2 = vector.transfer_read %alloc_0[%arg0, %arg0], %cst {in_bounds = [true, true]} : memref<4096x4096xf16>, vector<16x16xf16>
+// CHECK-NEXT:       %3 = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %2, %2, %arg1 : vector<16x16xf16>, vector<16x16xf16> into vector<16x16xf16>
+// CHECK-NEXT:       scf.yield %3 : vector<16x16xf16>
 // CHECK-NEXT:     }
+// CHECK-NEXT:     vector.transfer_write %1, %assume_align[%c0, %c0] {in_bounds = [true, true]} : vector<16x16xf16>, memref<4096x4096xf16>
 // CHECK-NEXT:     return
 // CHECK-NEXT:   }
 



More information about the Mlir-commits mailing list