[Mlir-commits] [mlir] [mlir][linalg] Prevent hoisting of transfer pairs in the presence of aliases (PR #145235)

Andrzej Warzyński llvmlistbot at llvm.org
Thu Jun 26 03:52:54 PDT 2025


https://github.com/banach-space updated https://github.com/llvm/llvm-project/pull/145235

>From 6dbaa082d07081714fe93c538d98e948bd100164 Mon Sep 17 00:00:00 2001
From: Andrzej Warzynski <andrzej.warzynski at arm.com>
Date: Sat, 21 Jun 2025 15:09:13 +0100
Subject: [PATCH 1/5] [mlir][linalg] Prevent hoisting of transfer pairs in the
 presence of aliases
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

This patch adds additional checks to the hoisting logic to prevent
hoisting of `vector.transfer_read`/`vector.transfer_write` pairs when
the underlying `memref` has users that introduce aliases via operations
implementing `ViewLikeOpInterface`.

Note: This may conservatively block some valid hoisting opportunities
and could impact performance. However, as demonstrated by the included
tests, the current behavior is too permissive and can lead to incorrect
transformations.

If this change prevents hoisting in cases that are provably safe, please
share a minimal repro — I’d be happy to explore ways to relax the check.
---
 .../Dialect/Linalg/Transforms/Hoisting.cpp    |  15 +-
 mlir/test/Dialect/Linalg/hoisting.mlir        | 140 ++++++++++++++++++
 2 files changed, 154 insertions(+), 1 deletion(-)

diff --git a/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp b/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp
index 707b63ff9335b..808925a934979 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp
@@ -303,7 +303,7 @@ void mlir::linalg::hoistRedundantVectorTransfers(Operation *root,
       //   1. indices, vector type and permutation map are the same (i.e., the
       //      transfer_read/transfer_write ops are matching),
       //   2. source operands for transfer.{read|write} do not originate from
-      //      Ops implementing ViewLikeOpInterface.
+      //      nor have users that are Ops implementing ViewLikeOpInterface.
       //   3. no other operations in the loop access the same memref except
       //      for transfer_read/transfer_write accessing statically disjoint
       //      slices.
@@ -312,14 +312,27 @@ void mlir::linalg::hoistRedundantVectorTransfers(Operation *root,
           transferRead.getPermutationMap() != transferWrite.getPermutationMap())
         return WalkResult::advance();
 
+      // Check 2. for xfer_read
       auto *source = transferRead.getBase().getDefiningOp();
       if (source && isa_and_nonnull<ViewLikeOpInterface>(source))
         return WalkResult::advance();
 
+      auto base = transferRead.getBase();
+      for (auto *user : base.getUsers())
+        if (isa_and_nonnull<ViewLikeOpInterface>(user))
+          return WalkResult::advance();
+
+      // Check 2. for xfer_wrire
       source = transferWrite.getBase().getDefiningOp();
       if (source && isa_and_nonnull<ViewLikeOpInterface>(source))
         return WalkResult::advance();
 
+      base = transferWrite.getBase();
+      for (auto *user : base.getUsers())
+        if (isa_and_nonnull<ViewLikeOpInterface>(user))
+          return WalkResult::advance();
+
+      // Check 1. + 3.
       // TODO: may want to memoize this information for performance but it
       // likely gets invalidated often.
       DominanceInfo dom(loop);
diff --git a/mlir/test/Dialect/Linalg/hoisting.mlir b/mlir/test/Dialect/Linalg/hoisting.mlir
index 8be4e1b79c52c..aec69465ca53c 100644
--- a/mlir/test/Dialect/Linalg/hoisting.mlir
+++ b/mlir/test/Dialect/Linalg/hoisting.mlir
@@ -1,5 +1,145 @@
 // RUN: mlir-opt  -transform-interpreter -canonicalize --split-input-file --allow-unregistered-dialect %s | FileCheck %s
 
+///----------------------------------------------------------------------------------------
+/// Tests for vector.transfer_read + vector.transfer_write pairs
+///
+/// * Indices are static
+/// * Single loop
+///----------------------------------------------------------------------------------------
+
+// The most basic example - hoisting is safe.
+
+// CHECK-LABEL:   func.func @hoist_basic_vector_xfer_pair(
+// CHECK-SAME:      %[[MEM:[a-zA-Z0-9]+]]: memref<?x?xf32>,
+// CHECK-SAME:      %[[LB:[a-zA-Z0-9]+]]: index,
+// CHECK-SAME:      %[[UB:[a-zA-Z0-9]+]]: index,
+// CHECK-SAME:      %[[STEP:[a-zA-Z0-9]+]]: index) {
+func.func @hoist_basic_vector_xfer_pair(
+    %mem: memref<?x?xf32>, %lb : index, %ub : index, %step: index) {
+  %c0 = arith.constant 0 : index
+  %pad = arith.constant 0.0 : f32
+
+// CHECK:           %[[C0:.*]] = arith.constant 0 : index
+// CHECK:           %[[PAD:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK:           %[[READ:.*]] = vector.transfer_read %[[MEM]][%[[C0]], %[[C0]]], %[[PAD]] : memref<?x?xf32>, vector<1xf32>
+// CHECK:           %[[SCF:.*]] = scf.for %[[I:.*]] = %[[LB]] to %[[UB]] step %[[STEP]] iter_args(%[[INIT:.*]] = %[[READ]]) -> (vector<1xf32>) {
+// CHECK:             %[[VAL_6:.*]] = "some_use"(%[[INIT]]) : (vector<1xf32>) -> vector<1xf32>
+// CHECK:             scf.yield %[[VAL_6]] : vector<1xf32>
+// CHECK:           }
+// CHECK:           vector.transfer_write %[[SCF]], %[[MEM]][%[[C0]], %[[C0]]] : vector<1xf32>, memref<?x?xf32>
+  scf.for %i = %lb to %ub step %step {
+      %r0 = vector.transfer_read %mem[%c0, %c0], %pad: memref<?x?xf32>, vector<1xf32>
+      %u0 = "some_use"(%r0) : (vector<1xf32>) -> vector<1xf32>
+      vector.transfer_write %u0, %mem[%c0, %c0] : vector<1xf32>, memref<?x?xf32>
+  }
+  return
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["func.func"]} in %arg1
+      : (!transform.any_op) -> !transform.any_op
+    transform.structured.hoist_redundant_vector_transfers %0
+      : (!transform.any_op) -> !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+// Similar to the example above, but hoisting is no longer safe. That's due to
+// an extra xfer_write inside the loop.
+
+// CHECK-LABEL:   func.func @negative_hoist_basic_vector_xfer_pair_extra_write(
+// CHECK-SAME:      %[[MEM:[a-zA-Z0-9]+]]: memref<?x?xf32>,
+// CHECK-SAME:      %[[LB:[a-zA-Z0-9]+]]: index,
+// CHECK-SAME:      %[[UB:[a-zA-Z0-9]+]]: index,
+// CHECK-SAME:      %[[STEP:[a-zA-Z0-9]+]]: index,
+// CHECK-SAME:      %[[IN:[a-zA-Z0-9]+]]: vector<1xf32>) {
+func.func @negative_hoist_basic_vector_xfer_pair_extra_write(
+    %mem: memref<?x?xf32>, %lb : index, %ub : index, %step: index, %in: vector<1xf32>) {
+  %c0 = arith.constant 0 : index
+  %pad = arith.constant 0.0 : f32
+
+// CHECK:           %[[C0:.*]] = arith.constant 0 : index
+// CHECK:           %[[PAD:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK:           scf.for %[[I:.*]] = %[[LB]] to %[[UB]] step %[[STEP]] {
+// CHECK:             vector.transfer_write %[[IN]], %[[MEM]][%[[C0]], %[[C0]]] : vector<1xf32>, memref<?x?xf32>
+// CHECK:             %[[READ:.*]] = vector.transfer_read %[[MEM]][%[[C0]], %[[C0]]], %[[PAD]] : memref<?x?xf32>, vector<1xf32>
+// CHECK:             %[[USE:.*]] = "some_use"(%[[READ]]) : (vector<1xf32>) -> vector<1xf32>
+// CHECK:             vector.transfer_write %[[USE]], %[[MEM]][%[[C0]], %[[C0]]] : vector<1xf32>, memref<?x?xf32>
+// CHECK:           }
+
+  scf.for %i = %lb to %ub step %step {
+      vector.transfer_write %in, %mem[%c0, %c0] : vector<1xf32>, memref<?x?xf32>
+
+      %r0 = vector.transfer_read %mem[%c0, %c0], %pad: memref<?x?xf32>, vector<1xf32>
+      %u0 = "some_use"(%r0) : (vector<1xf32>) -> vector<1xf32>
+      vector.transfer_write %u0, %mem[%c0, %c0] : vector<1xf32>, memref<?x?xf32>
+  }
+  return
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["func.func"]} in %arg1
+      : (!transform.any_op) -> !transform.any_op
+    transform.structured.hoist_redundant_vector_transfers %0
+      : (!transform.any_op) -> !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+// Similar to the example above, but hoisting is no longer safe. That's due to
+// an extra xfer_write into _an alias_ of the %mem Op that is used by the
+// original xfer pair.
+
+// CHECK-LABEL:   func.func @negative_hoist_basic_vector_xfer_pair_extra_write_into_alias(
+// CHECK-SAME:      %[[MEM:[a-zA-Z0-9]+]]: memref<?x?xf32>,
+// CHECK-SAME:      %[[LB:[a-zA-Z0-9]+]]: index,
+// CHECK-SAME:      %[[UB:[a-zA-Z0-9]+]]: index,
+// CHECK-SAME:      %[[STEP:[a-zA-Z0-9]+]]: index,
+// CHECK-SAME:      %[[IN:[a-zA-Z0-9]+]]: vector<1xf32>) {
+func.func @negative_hoist_basic_vector_xfer_pair_extra_write_into_alias(
+    %mem: memref<?x?xf32>, %lb : index, %ub : index, %step: index, %in: vector<1xf32>) {
+  %c0 = arith.constant 0 : index
+  %pad = arith.constant 0.0 : f32
+
+// CHECK:           %[[C0:.*]] = arith.constant 0 : index
+// CHECK:           %[[PAD:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK:           %[[SV:.*]] = memref.subview %[[MEM]][0, 0] [1, 1] [1, 1] : memref<?x?xf32> to memref<1x1xf32, strided<[?, 1]>>
+// CHECK:           scf.for %[[I:.*]] = %[[LB]] to %[[UB]] step %[[STEP]] {
+// CHECK:             vector.transfer_write %[[IN]], %[[SV]][%[[C0]], %[[C0]]] {{.*}} : vector<1xf32>, memref<1x1xf32, strided<[?, 1]>>
+// CHECK:             %[[READ:.*]] = vector.transfer_read %[[MEM]][%[[C0]], %[[C0]]], %[[PAD]] : memref<?x?xf32>, vector<1xf32>
+// CHECK:             %[[USE:.*]] = "some_use"(%[[READ]]) : (vector<1xf32>) -> vector<1xf32>
+// CHECK:             vector.transfer_write %[[USE]], %[[MEM]][%[[C0]], %[[C0]]] : vector<1xf32>, memref<?x?xf32>
+// CHECK:           }
+
+  %sv = memref.subview %mem[0, 0][1, 1][1, 1] : memref<?x?xf32> to memref<1x1xf32, strided<[?, 1]>>
+  scf.for %i = %lb to %ub step %step {
+      vector.transfer_write %in, %sv[%c0, %c0] : vector<1xf32>, memref<1x1xf32, strided<[?, 1]>>
+
+      %r0 = vector.transfer_read %mem[%c0, %c0], %pad: memref<?x?xf32>, vector<1xf32>
+      %u0 = "some_use"(%r0) : (vector<1xf32>) -> vector<1xf32>
+      vector.transfer_write %u0, %mem[%c0, %c0] : vector<1xf32>, memref<?x?xf32>
+  }
+  return
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["func.func"]} in %arg1
+      : (!transform.any_op) -> !transform.any_op
+    transform.structured.hoist_redundant_vector_transfers %0
+      : (!transform.any_op) -> !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
 ///----------------------------------------------------------------------------------------
 /// Tests for vector.transfer_read + vector.transfer_write pairs
 ///

>From 0fe657aabb67a22e7dd79b029347a4f16d245c0e Mon Sep 17 00:00:00 2001
From: Andrzej Warzynski <andrzej.warzynski at arm.com>
Date: Wed, 25 Jun 2025 10:20:22 +0100
Subject: [PATCH 2/5] fixup! [mlir][linalg] Prevent hoisting of transfer pairs
 in the presence of aliases

1. Relax the conditions in the case of `memref.assume_alignment`. This
   unblocks hoisting for examples like the one reported in
   https://github.com/llvm/llvm-project/issues/144825 (see the link in
   the top post: "detailed example pls refer to example").
2. When checking the source operand for the xfer Ops, we only need to
   look at either xfer_write or xfer_read (we already know that the
   source is identical). The corresponding logic has been simplified.
---
 .../Dialect/Linalg/Transforms/Hoisting.cpp    | 40 ++++++++++-------
 mlir/test/Dialect/Linalg/hoisting.mlir        | 45 +++++++++++++++++++
 2 files changed, 70 insertions(+), 15 deletions(-)

diff --git a/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp b/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp
index 808925a934979..1a5cb322074aa 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp
@@ -307,32 +307,41 @@ void mlir::linalg::hoistRedundantVectorTransfers(Operation *root,
       //   3. no other operations in the loop access the same memref except
       //      for transfer_read/transfer_write accessing statically disjoint
       //      slices.
+
+      // Check 1.
       if (transferRead.getIndices() != transferWrite.getIndices() ||
           transferRead.getVectorType() != transferWrite.getVectorType() ||
           transferRead.getPermutationMap() != transferWrite.getPermutationMap())
         return WalkResult::advance();
 
-      // Check 2. for xfer_read
-      auto *source = transferRead.getBase().getDefiningOp();
-      if (source && isa_and_nonnull<ViewLikeOpInterface>(source))
-        return WalkResult::advance();
-
+      // Check 2. Note, since both xfer Ops share the source, we only need to
+      // look at one of them.
       auto base = transferRead.getBase();
-      for (auto *user : base.getUsers())
-        if (isa_and_nonnull<ViewLikeOpInterface>(user))
+      auto *source = base.getDefiningOp();
+      if (source) {
+        // NOTE: We treat `memref.assume_alignment` as a special case:
+        //  1. If it has exactly two uses then these have to be the xfer Ops
+        //    being looked at.
+        //  2. Otherwise, there are other users that we should take into
+        //    account
+        // In the case of 1., it is safe to look past AssumeAlignmentOp,
+        // i.e. at the defining Op of the input MemRef, provided that:
+        //  * the original MemRef has only one use (i.e.
+        //  `memref.assume_alignment`)
+        if (auto assume = dyn_cast<memref::AssumeAlignmentOp>(source)) {
+          Value memPreAlignment = assume.getMemref();
+          if (base.hasNUses(2) && memPreAlignment.hasOneUse())
+            source = memPreAlignment.getDefiningOp();
+        }
+        if (isa_and_nonnull<ViewLikeOpInterface>(source))
           return WalkResult::advance();
+      }
 
-      // Check 2. for xfer_wrire
-      source = transferWrite.getBase().getDefiningOp();
-      if (source && isa_and_nonnull<ViewLikeOpInterface>(source))
-        return WalkResult::advance();
-
-      base = transferWrite.getBase();
       for (auto *user : base.getUsers())
         if (isa_and_nonnull<ViewLikeOpInterface>(user))
           return WalkResult::advance();
 
-      // Check 1. + 3.
+      // Check 3.
       // TODO: may want to memoize this information for performance but it
       // likely gets invalidated often.
       DominanceInfo dom(loop);
@@ -371,7 +380,8 @@ void mlir::linalg::hoistRedundantVectorTransfers(Operation *root,
       // Hoist write after.
       transferWrite->moveAfter(loop);
 
-      // Rewrite `loop` with new yields by cloning and erase the original loop.
+      // Rewrite `loop` with new yields by cloning and erase the original
+      // loop.
       IRRewriter rewriter(transferRead.getContext());
       NewYieldValuesFn yieldFn = [&](OpBuilder &b, Location loc,
                                      ArrayRef<BlockArgument> newBBArgs) {
diff --git a/mlir/test/Dialect/Linalg/hoisting.mlir b/mlir/test/Dialect/Linalg/hoisting.mlir
index aec69465ca53c..c0ec08d08da4e 100644
--- a/mlir/test/Dialect/Linalg/hoisting.mlir
+++ b/mlir/test/Dialect/Linalg/hoisting.mlir
@@ -140,6 +140,51 @@ module attributes {transform.with_named_sequence} {
 
 // -----
 
+// Similar to the example above, but the memory access is done via
+// memref.assume_alignment. Hoisting is safe as the only users of the
+// "alignment" Op are the xfer Ops within the loop that we want to hoist.
+
+// CHECK-LABEL:   func.func @hoist_basic_vector_xfer_pair_with_assume_align(
+// CHECK-SAME:      %[[MEM:[a-zA-Z0-9]+]]: memref<?x?xf32>,
+// CHECK-SAME:      %[[LB:[a-zA-Z0-9]+]]: index,
+// CHECK-SAME:      %[[UB:[a-zA-Z0-9]+]]: index,
+// CHECK-SAME:      %[[STEP:[a-zA-Z0-9]+]]: index,
+// CHECK-SAME:      %[[IN:[a-zA-Z0-9]+]]: vector<1xf32>) {
+func.func @hoist_basic_vector_xfer_pair_with_assume_align(
+    %mem: memref<?x?xf32>, %lb : index, %ub : index, %step: index, %in: vector<1xf32>) {
+  %c0 = arith.constant 0 : index
+  %pad = arith.constant 0.0 : f32
+
+// CHECK:           %[[C0:.*]] = arith.constant 0 : index
+// CHECK:           %[[PAD:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK:           %[[AA:.*]] = memref.assume_alignment %[[MEM]], 4 : memref<?x?xf32>
+// CHECK:           %[[READ:.*]] = vector.transfer_read %[[AA]][%[[C0]], %[[C0]]], %[[PAD]] : memref<?x?xf32>, vector<1xf32>
+// CHECK:           %[[SCF:.*]] = scf.for %[[I:.*]] = %[[LB]] to %[[UB]] step %[[STEP]]  iter_args(%[[INIT:.*]] = %[[READ]]) -> (vector<1xf32>) {
+// CHECK:             %[[USE:.*]] = "some_use"(%[[INIT]]) : (vector<1xf32>) -> vector<1xf32>
+// CHECK:           }
+// CHECK:           vector.transfer_write %[[SCF]], %[[AA]][%[[C0]], %[[C0]]] : vector<1xf32>, memref<?x?xf32>
+
+  %aa = memref.assume_alignment %mem, 4 : memref<?x?xf32>
+  scf.for %i = %lb to %ub step %step {
+      %r0 = vector.transfer_read %aa[%c0, %c0], %pad: memref<?x?xf32>, vector<1xf32>
+      %u0 = "some_use"(%r0) : (vector<1xf32>) -> vector<1xf32>
+      vector.transfer_write %u0, %aa[%c0, %c0] : vector<1xf32>, memref<?x?xf32>
+  }
+  return
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["func.func"]} in %arg1
+      : (!transform.any_op) -> !transform.any_op
+    transform.structured.hoist_redundant_vector_transfers %0
+      : (!transform.any_op) -> !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
 ///----------------------------------------------------------------------------------------
 /// Tests for vector.transfer_read + vector.transfer_write pairs
 ///

>From c160d764ea1d5c67a834404a8e19cea7d2c55154 Mon Sep 17 00:00:00 2001
From: Andrzej Warzynski <andrzej.warzynski at arm.com>
Date: Wed, 25 Jun 2025 11:00:23 +0100
Subject: [PATCH 3/5] fixup! fixup! [mlir][linalg] Prevent hoisting of transfer
 pairs in the presence of aliases

Make sure only uses inside the loop are counted
---
 mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp | 7 ++++++-
 1 file changed, 6 insertions(+), 1 deletion(-)

diff --git a/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp b/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp
index 1a5cb322074aa..4b1fdd4a4a3bc 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp
@@ -330,7 +330,12 @@ void mlir::linalg::hoistRedundantVectorTransfers(Operation *root,
         //  `memref.assume_alignment`)
         if (auto assume = dyn_cast<memref::AssumeAlignmentOp>(source)) {
           Value memPreAlignment = assume.getMemref();
-          if (base.hasNUses(2) && memPreAlignment.hasOneUse())
+          auto numInLoopUses =
+              llvm::count_if(base.getUses(), [&loop](OpOperand &use) {
+                return loop->isAncestor(use.getOwner());
+              });
+
+          if (numInLoopUses && memPreAlignment.hasOneUse())
             source = memPreAlignment.getDefiningOp();
         }
         if (isa_and_nonnull<ViewLikeOpInterface>(source))

>From a991957918fe414fa37c0f5a0b40249ea0cbbbd1 Mon Sep 17 00:00:00 2001
From: Andrzej Warzynski <andrzej.warzynski at arm.com>
Date: Thu, 26 Jun 2025 10:59:49 +0100
Subject: [PATCH 4/5] fixup! fixup! fixup! [mlir][linalg] Prevent hoisting of
 transfer pairs in the presence of aliases

Incorporate suggestions from HanHan, update the description.
---
 .../Dialect/Linalg/Transforms/Hoisting.cpp    | 25 ++++++++++---------
 1 file changed, 13 insertions(+), 12 deletions(-)

diff --git a/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp b/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp
index 4b1fdd4a4a3bc..d833e04d60264 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp
@@ -319,15 +319,17 @@ void mlir::linalg::hoistRedundantVectorTransfers(Operation *root,
       auto base = transferRead.getBase();
       auto *source = base.getDefiningOp();
       if (source) {
-        // NOTE: We treat `memref.assume_alignment` as a special case:
-        //  1. If it has exactly two uses then these have to be the xfer Ops
-        //    being looked at.
-        //  2. Otherwise, there are other users that we should take into
-        //    account
-        // In the case of 1., it is safe to look past AssumeAlignmentOp,
-        // i.e. at the defining Op of the input MemRef, provided that:
-        //  * the original MemRef has only one use (i.e.
-        //  `memref.assume_alignment`)
+        // NOTE: We treat `memref.assume_alignment` as a special case.
+        //
+        // The idea is that it is safe to look past AssumeAlignmentOp (i.e.
+        // at the MemRef _before_ alignment) iff:
+        //  1. It has exactly two uses (these have to be the xfer Ops
+        //     being looked at).
+        //  2. The original MemRef has only one use (i.e.
+        //     AssumeAlignmentOp).
+        //
+        // Relaxing these conditions will most likely require proper alias
+        // analysis.
         if (auto assume = dyn_cast<memref::AssumeAlignmentOp>(source)) {
           Value memPreAlignment = assume.getMemref();
           auto numInLoopUses =
@@ -342,9 +344,8 @@ void mlir::linalg::hoistRedundantVectorTransfers(Operation *root,
           return WalkResult::advance();
       }
 
-      for (auto *user : base.getUsers())
-        if (isa_and_nonnull<ViewLikeOpInterface>(user))
-          return WalkResult::advance();
+      if (llvm::any_of(base.getUsers(), llvm::IsaPred<ViewLikeOpInterface>))
+        return WalkResult::advance();
 
       // Check 3.
       // TODO: may want to memoize this information for performance but it

>From d4d95dac72b5ee676e49865f86f60066171b9dc2 Mon Sep 17 00:00:00 2001
From: Andrzej Warzynski <andrzej.warzynski at arm.com>
Date: Thu, 26 Jun 2025 11:24:56 +0100
Subject: [PATCH 5/5] fixup! fixup! fixup! fixup! [mlir][linalg] Prevent
 hoisting of transfer pairs in the presence of aliases

Extra test
---
 mlir/test/Dialect/Linalg/hoisting.mlir | 64 ++++++++++++++++++++++----
 1 file changed, 54 insertions(+), 10 deletions(-)

diff --git a/mlir/test/Dialect/Linalg/hoisting.mlir b/mlir/test/Dialect/Linalg/hoisting.mlir
index c0ec08d08da4e..aa0b97a4787fa 100644
--- a/mlir/test/Dialect/Linalg/hoisting.mlir
+++ b/mlir/test/Dialect/Linalg/hoisting.mlir
@@ -3,8 +3,8 @@
 ///----------------------------------------------------------------------------------------
 /// Tests for vector.transfer_read + vector.transfer_write pairs
 ///
-/// * Indices are static
-/// * Single loop
+/// * Nested inside a single loop
+/// * Indices are constant
 ///----------------------------------------------------------------------------------------
 
 // The most basic example - hoisting is safe.
@@ -23,13 +23,13 @@ func.func @hoist_basic_vector_xfer_pair(
 // CHECK:           %[[PAD:.*]] = arith.constant 0.000000e+00 : f32
 // CHECK:           %[[READ:.*]] = vector.transfer_read %[[MEM]][%[[C0]], %[[C0]]], %[[PAD]] : memref<?x?xf32>, vector<1xf32>
 // CHECK:           %[[SCF:.*]] = scf.for %[[I:.*]] = %[[LB]] to %[[UB]] step %[[STEP]] iter_args(%[[INIT:.*]] = %[[READ]]) -> (vector<1xf32>) {
-// CHECK:             %[[VAL_6:.*]] = "some_use"(%[[INIT]]) : (vector<1xf32>) -> vector<1xf32>
+// CHECK:             %[[VAL_6:.*]] = "val_use"(%[[INIT]]) : (vector<1xf32>) -> vector<1xf32>
 // CHECK:             scf.yield %[[VAL_6]] : vector<1xf32>
 // CHECK:           }
 // CHECK:           vector.transfer_write %[[SCF]], %[[MEM]][%[[C0]], %[[C0]]] : vector<1xf32>, memref<?x?xf32>
   scf.for %i = %lb to %ub step %step {
       %r0 = vector.transfer_read %mem[%c0, %c0], %pad: memref<?x?xf32>, vector<1xf32>
-      %u0 = "some_use"(%r0) : (vector<1xf32>) -> vector<1xf32>
+      %u0 = "val_use"(%r0) : (vector<1xf32>) -> vector<1xf32>
       vector.transfer_write %u0, %mem[%c0, %c0] : vector<1xf32>, memref<?x?xf32>
   }
   return
@@ -66,7 +66,7 @@ func.func @negative_hoist_basic_vector_xfer_pair_extra_write(
 // CHECK:           scf.for %[[I:.*]] = %[[LB]] to %[[UB]] step %[[STEP]] {
 // CHECK:             vector.transfer_write %[[IN]], %[[MEM]][%[[C0]], %[[C0]]] : vector<1xf32>, memref<?x?xf32>
 // CHECK:             %[[READ:.*]] = vector.transfer_read %[[MEM]][%[[C0]], %[[C0]]], %[[PAD]] : memref<?x?xf32>, vector<1xf32>
-// CHECK:             %[[USE:.*]] = "some_use"(%[[READ]]) : (vector<1xf32>) -> vector<1xf32>
+// CHECK:             %[[USE:.*]] = "val_use"(%[[READ]]) : (vector<1xf32>) -> vector<1xf32>
 // CHECK:             vector.transfer_write %[[USE]], %[[MEM]][%[[C0]], %[[C0]]] : vector<1xf32>, memref<?x?xf32>
 // CHECK:           }
 
@@ -74,7 +74,7 @@ func.func @negative_hoist_basic_vector_xfer_pair_extra_write(
       vector.transfer_write %in, %mem[%c0, %c0] : vector<1xf32>, memref<?x?xf32>
 
       %r0 = vector.transfer_read %mem[%c0, %c0], %pad: memref<?x?xf32>, vector<1xf32>
-      %u0 = "some_use"(%r0) : (vector<1xf32>) -> vector<1xf32>
+      %u0 = "val_use"(%r0) : (vector<1xf32>) -> vector<1xf32>
       vector.transfer_write %u0, %mem[%c0, %c0] : vector<1xf32>, memref<?x?xf32>
   }
   return
@@ -113,7 +113,7 @@ func.func @negative_hoist_basic_vector_xfer_pair_extra_write_into_alias(
 // CHECK:           scf.for %[[I:.*]] = %[[LB]] to %[[UB]] step %[[STEP]] {
 // CHECK:             vector.transfer_write %[[IN]], %[[SV]][%[[C0]], %[[C0]]] {{.*}} : vector<1xf32>, memref<1x1xf32, strided<[?, 1]>>
 // CHECK:             %[[READ:.*]] = vector.transfer_read %[[MEM]][%[[C0]], %[[C0]]], %[[PAD]] : memref<?x?xf32>, vector<1xf32>
-// CHECK:             %[[USE:.*]] = "some_use"(%[[READ]]) : (vector<1xf32>) -> vector<1xf32>
+// CHECK:             %[[USE:.*]] = "val_use"(%[[READ]]) : (vector<1xf32>) -> vector<1xf32>
 // CHECK:             vector.transfer_write %[[USE]], %[[MEM]][%[[C0]], %[[C0]]] : vector<1xf32>, memref<?x?xf32>
 // CHECK:           }
 
@@ -122,7 +122,7 @@ func.func @negative_hoist_basic_vector_xfer_pair_extra_write_into_alias(
       vector.transfer_write %in, %sv[%c0, %c0] : vector<1xf32>, memref<1x1xf32, strided<[?, 1]>>
 
       %r0 = vector.transfer_read %mem[%c0, %c0], %pad: memref<?x?xf32>, vector<1xf32>
-      %u0 = "some_use"(%r0) : (vector<1xf32>) -> vector<1xf32>
+      %u0 = "val_use"(%r0) : (vector<1xf32>) -> vector<1xf32>
       vector.transfer_write %u0, %mem[%c0, %c0] : vector<1xf32>, memref<?x?xf32>
   }
   return
@@ -160,14 +160,14 @@ func.func @hoist_basic_vector_xfer_pair_with_assume_align(
 // CHECK:           %[[AA:.*]] = memref.assume_alignment %[[MEM]], 4 : memref<?x?xf32>
 // CHECK:           %[[READ:.*]] = vector.transfer_read %[[AA]][%[[C0]], %[[C0]]], %[[PAD]] : memref<?x?xf32>, vector<1xf32>
 // CHECK:           %[[SCF:.*]] = scf.for %[[I:.*]] = %[[LB]] to %[[UB]] step %[[STEP]]  iter_args(%[[INIT:.*]] = %[[READ]]) -> (vector<1xf32>) {
-// CHECK:             %[[USE:.*]] = "some_use"(%[[INIT]]) : (vector<1xf32>) -> vector<1xf32>
+// CHECK:             %[[USE:.*]] = "val_use"(%[[INIT]]) : (vector<1xf32>) -> vector<1xf32>
 // CHECK:           }
 // CHECK:           vector.transfer_write %[[SCF]], %[[AA]][%[[C0]], %[[C0]]] : vector<1xf32>, memref<?x?xf32>
 
   %aa = memref.assume_alignment %mem, 4 : memref<?x?xf32>
   scf.for %i = %lb to %ub step %step {
       %r0 = vector.transfer_read %aa[%c0, %c0], %pad: memref<?x?xf32>, vector<1xf32>
-      %u0 = "some_use"(%r0) : (vector<1xf32>) -> vector<1xf32>
+      %u0 = "val_use"(%r0) : (vector<1xf32>) -> vector<1xf32>
       vector.transfer_write %u0, %aa[%c0, %c0] : vector<1xf32>, memref<?x?xf32>
   }
   return
@@ -185,6 +185,50 @@ module attributes {transform.with_named_sequence} {
 
 // -----
 
+// Similar to the example above, but hoisting is not safe due to an extra
+// memory access inside the loop via the original memref.
+
+// CHECK-LABEL:   func.func @negative_hoist_basic_vector_xfer_pair_with_assume_align(
+// CHECK-SAME:      %[[MEM:[a-zA-Z0-9]+]]: memref<?x?xf32>,
+// CHECK-SAME:      %[[LB:[a-zA-Z0-9]+]]: index,
+// CHECK-SAME:      %[[UB:[a-zA-Z0-9]+]]: index,
+// CHECK-SAME:      %[[STEP:[a-zA-Z0-9]+]]: index,
+// CHECK-SAME:      %[[IN:[a-zA-Z0-9]+]]: vector<1xf32>) {
+func.func @negative_hoist_basic_vector_xfer_pair_with_assume_align(
+    %mem: memref<?x?xf32>, %lb : index, %ub : index, %step: index, %in: vector<1xf32>) {
+  %c0 = arith.constant 0 : index
+  %pad = arith.constant 0.0 : f32
+
+// CHECK:           %[[C0:.*]] = arith.constant 0 : index
+// CHECK:           %[[PAD:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK:           %[[AA:.*]] = memref.assume_alignment %[[MEM]], 4 : memref<?x?xf32>
+// CHECK:           scf.for %[[I:.*]] = %[[LB]] to %[[UB]] step %[[STEP]] {
+// CHECK:             %[[READ:.*]] = vector.transfer_read %[[AA]][%[[C0]], %[[C0]]], %[[PAD]] : memref<?x?xf32>, vector<1xf32>
+// CHECK:             "mem_use"(%[[MEM]])
+// CHECK:             vector.transfer_write %[[READ]], %[[AA]][%[[C0]], %[[C0]]] : vector<1xf32>, memref<?x?xf32>
+// CHECK:           }
+
+  %aa = memref.assume_alignment %mem, 4 : memref<?x?xf32>
+  scf.for %i = %lb to %ub step %step {
+      %r0 = vector.transfer_read %aa[%c0, %c0], %pad: memref<?x?xf32>, vector<1xf32>
+      "mem_use"(%mem) : (memref<?x?xf32>) -> ()
+      vector.transfer_write %r0, %aa[%c0, %c0] : vector<1xf32>, memref<?x?xf32>
+  }
+  return
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["func.func"]} in %arg1
+      : (!transform.any_op) -> !transform.any_op
+    transform.structured.hoist_redundant_vector_transfers %0
+      : (!transform.any_op) -> !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
 ///----------------------------------------------------------------------------------------
 /// Tests for vector.transfer_read + vector.transfer_write pairs
 ///



More information about the Mlir-commits mailing list