[Mlir-commits] [mlir] [mlir][hoisting] Support memref.assume_alignment in linalg hoisting (PR #144843)
llvmlistbot at llvm.org
Wed Jun 18 23:03:01 PDT 2025
https://github.com/xiangzh1 created https://github.com/llvm/llvm-project/pull/144843
The recent updates to AssumeAlignmentOp affect the linalg hoisting optimization: we found a regression in "hoist load/store out of loop".
The following issue gives more detail:
related issue: [144825](https://github.com/llvm/llvm-project/issues/144825)
This patch fixes the problem: assume_alignment merely annotates a memref's alignment, so linalg hoisting should check its memref operand rather than the op itself.
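In other words, when a transfer's base is defined by memref.assume_alignment, the view-like check can look through it to the memref it annotates. A minimal sketch of that idea, assuming the same MLIR includes as Hoisting.cpp (the helper name getUnderlyingDefiningOp is hypothetical, not part of the patch):

  // Hypothetical helper: if the base value comes from memref.assume_alignment,
  // return the defining op of the annotated memref rather than the
  // assume_alignment op itself, so the ViewLike check sees the real producer.
  static Operation *getUnderlyingDefiningOp(Value base) {
    if (auto assumeOp = base.getDefiningOp<memref::AssumeAlignmentOp>())
      return assumeOp.getMemref().getDefiningOp();
    return base.getDefiningOp();
  }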
From 3b387e7baae4a098ee8bd90ecf1797967c2c43bf Mon Sep 17 00:00:00 2001
From: Zhang Xiang <xiang.zhang at iluvatar.com>
Date: Wed, 18 Jun 2025 17:10:15 +0800
Subject: [PATCH 1/2] [mlir][NFC] Pre-commit test for linalg hoisting
---
mlir/test/Dialect/Linalg/hoisting.mlir | 51 ++++++++++++++++++++++++++
1 file changed, 51 insertions(+)
diff --git a/mlir/test/Dialect/Linalg/hoisting.mlir b/mlir/test/Dialect/Linalg/hoisting.mlir
index 318edca73cce1..67dfe7a2af98b 100644
--- a/mlir/test/Dialect/Linalg/hoisting.mlir
+++ b/mlir/test/Dialect/Linalg/hoisting.mlir
@@ -802,3 +802,54 @@ module attributes {transform.with_named_sequence} {
transform.yield
}
}
+
+// -----
+
+// Test hoisting of a vector.transfer_read/transfer_write pair that uses the
+// same location, where that location is marked with memref.assume_alignment.
+
+// CHECK-LABEL: func.func @hoist_vector_transfer_read_write() {
+// CHECK: %c0 = arith.constant 0 : index
+// CHECK-NEXT: %c256 = arith.constant 256 : index
+// CHECK-NEXT: %c4096 = arith.constant 4096 : index
+// CHECK-NEXT: %cst = arith.constant 0.000000e+00 : f16
+// CHECK-NEXT: %alloc = memref.alloc() : memref<4096x4096xf16>
+// CHECK-NEXT: %alloc_0 = memref.alloc() : memref<4096x4096xf16>
+// CHECK-NEXT: %assume_align = memref.assume_alignment %alloc, 64 : memref<4096x4096xf16>
+// CHECK-NEXT: scf.for %arg0 = %c256 to %c4096 step %c256 {
+// CHECK-NEXT: %0 = vector.transfer_read %assume_align[%c0, %c0], %cst {in_bounds = [true, true]} : memref<4096x4096xf16>, vector<16x16xf16>
+// CHECK-NEXT: %1 = vector.transfer_read %alloc_0[%arg0, %arg0], %cst {in_bounds = [true, true]} : memref<4096x4096xf16>, vector<16x16xf16>
+// CHECK-NEXT: %2 = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %1, %1, %0 : vector<16x16xf16>, vector<16x16xf16> into vector<16x16xf16>
+// CHECK-NEXT: vector.transfer_write %2, %assume_align[%c0, %c0] {in_bounds = [true, true]} : vector<16x16xf16>, memref<4096x4096xf16>
+// CHECK-NEXT: }
+// CHECK-NEXT: return
+// CHECK-NEXT: }
+
+func.func @hoist_vector_transfer_read_write() {
+ %c0 = arith.constant 0 : index
+ %c64 = arith.constant 64 : index
+ %c256 = arith.constant 256 : index
+ %c4096 = arith.constant 4096 : index
+ %cst_0 = arith.constant 0.000000e+00 : f16
+ %m0 = memref.alloc() : memref<4096x4096xf16>
+ %m1 = memref.alloc() : memref<4096x4096xf16>
+ %assume_align_0 = memref.assume_alignment %m0, 64 : memref<4096x4096xf16>
+ %assume_align_1 = memref.assume_alignment %m1, 64 : memref<4096x4096xf16>
+ scf.for %arg0 = %c256 to %c4096 step %c256 {
+ %1 = vector.transfer_read %assume_align_0[%c0, %c0], %cst_0 {in_bounds = [true, true]} : memref<4096x4096xf16>, vector<16x16xf16>
+ %2 = vector.transfer_read %m1[%arg0, %arg0], %cst_0 {in_bounds = [true, true]} : memref<4096x4096xf16>, vector<16x16xf16>
+ %3 = vector.contract {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %2, %2, %1 : vector<16x16xf16>, vector<16x16xf16> into vector<16x16xf16>
+ vector.transfer_write %3, %assume_align_0[%c0, %c0] {in_bounds = [true, true]} : vector<16x16xf16>, memref<4096x4096xf16>
+ }
+ return
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["func.func"]} in %arg1
+ : (!transform.any_op) -> !transform.any_op
+ transform.structured.hoist_redundant_vector_transfers %0
+ : (!transform.any_op) -> !transform.any_op
+ transform.yield
+ }
+}
From 706c1daea66acd1b899aac37a6716a1bf9255c3c Mon Sep 17 00:00:00 2001
From: Zhang Xiang <xiang.zhang at iluvatar.com>
Date: Wed, 18 Jun 2025 14:05:04 +0800
Subject: [PATCH 2/2] [mlir][hoisting] Support memref.assume_alignment in
linalg hoisting
All ViewLike operations are excluded from the hoisting optimization. But
assume_alignment merely annotates a memref's alignment, so we should check
its memref operand instead of the op itself.
---
.../Dialect/Linalg/Transforms/Hoisting.cpp | 26 +++++++++++++++----
mlir/test/Dialect/Linalg/hoisting.mlir | 11 ++++----
2 files changed, 27 insertions(+), 10 deletions(-)
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp b/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp
index 707b63ff9335b..b949b06631484 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp
@@ -199,6 +199,24 @@ static bool noAliasingUseInLoop(vector::TransferReadOp transferRead,
return true;
}
+static bool skipViewLike(Operation *source0, Operation *source1) {
+  // assume_alignment only annotates alignment; check its memref operand.
+  auto assumeAlignOp = dyn_cast_or_null<memref::AssumeAlignmentOp>(source0);
+  if (assumeAlignOp && source0 == source1) {
+    Value sourceMemRef = assumeAlignOp.getMemref();
+    Operation *sourceOp = sourceMemRef.getDefiningOp();
+    return isa_and_nonnull<ViewLikeOpInterface>(sourceOp);
+  }
+
+  if (isa_and_nonnull<ViewLikeOpInterface>(source0))
+    return true;
+
+  if (isa_and_nonnull<ViewLikeOpInterface>(source1))
+    return true;
+
+  return false;
+}
+
void mlir::linalg::hoistRedundantVectorTransfers(Operation *root,
bool verifyNonZeroTrip) {
bool changed = true;
@@ -312,12 +330,10 @@ void mlir::linalg::hoistRedundantVectorTransfers(Operation *root,
transferRead.getPermutationMap() != transferWrite.getPermutationMap())
return WalkResult::advance();
- auto *source = transferRead.getBase().getDefiningOp();
- if (source && isa_and_nonnull<ViewLikeOpInterface>(source))
- return WalkResult::advance();
+ auto *source0 = transferRead.getBase().getDefiningOp();
+ auto *source1 = transferWrite.getBase().getDefiningOp();
- source = transferWrite.getBase().getDefiningOp();
- if (source && isa_and_nonnull<ViewLikeOpInterface>(source))
+ if (skipViewLike(source0, source1))
return WalkResult::advance();
// TODO: may want to memoize this information for performance but it
diff --git a/mlir/test/Dialect/Linalg/hoisting.mlir b/mlir/test/Dialect/Linalg/hoisting.mlir
index 67dfe7a2af98b..c58074e40c5f4 100644
--- a/mlir/test/Dialect/Linalg/hoisting.mlir
+++ b/mlir/test/Dialect/Linalg/hoisting.mlir
@@ -816,12 +816,13 @@ module attributes {transform.with_named_sequence} {
// CHECK-NEXT: %alloc = memref.alloc() : memref<4096x4096xf16>
// CHECK-NEXT: %alloc_0 = memref.alloc() : memref<4096x4096xf16>
// CHECK-NEXT: %assume_align = memref.assume_alignment %alloc, 64 : memref<4096x4096xf16>
-// CHECK-NEXT: scf.for %arg0 = %c256 to %c4096 step %c256 {
-// CHECK-NEXT: %0 = vector.transfer_read %assume_align[%c0, %c0], %cst {in_bounds = [true, true]} : memref<4096x4096xf16>, vector<16x16xf16>
-// CHECK-NEXT: %1 = vector.transfer_read %alloc_0[%arg0, %arg0], %cst {in_bounds = [true, true]} : memref<4096x4096xf16>, vector<16x16xf16>
-// CHECK-NEXT: %2 = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %1, %1, %0 : vector<16x16xf16>, vector<16x16xf16> into vector<16x16xf16>
-// CHECK-NEXT: vector.transfer_write %2, %assume_align[%c0, %c0] {in_bounds = [true, true]} : vector<16x16xf16>, memref<4096x4096xf16>
+// CHECK-NEXT: %0 = vector.transfer_read %assume_align[%c0, %c0], %cst {in_bounds = [true, true]} : memref<4096x4096xf16>, vector<16x16xf16>
+// CHECK-NEXT: %1 = scf.for %arg0 = %c256 to %c4096 step %c256 iter_args(%arg1 = %0) -> (vector<16x16xf16>) {
+// CHECK-NEXT: %2 = vector.transfer_read %alloc_0[%arg0, %arg0], %cst {in_bounds = [true, true]} : memref<4096x4096xf16>, vector<16x16xf16>
+// CHECK-NEXT: %3 = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %2, %2, %arg1 : vector<16x16xf16>, vector<16x16xf16> into vector<16x16xf16>
+// CHECK-NEXT: scf.yield %3 : vector<16x16xf16>
// CHECK-NEXT: }
+// CHECK-NEXT: vector.transfer_write %1, %assume_align[%c0, %c0] {in_bounds = [true, true]} : vector<16x16xf16>, memref<4096x4096xf16>
// CHECK-NEXT: return
// CHECK-NEXT: }