[Mlir-commits] [mlir] [mlir][linalg] Prevent hoisting of transfer pairs in the presence of aliases (PR #145235)

Andrzej Warzyński llvmlistbot at llvm.org
Wed Jun 25 03:00:55 PDT 2025


https://github.com/banach-space updated https://github.com/llvm/llvm-project/pull/145235

>From f7218d71dd69f63b14e6fb5d3da06228982b7423 Mon Sep 17 00:00:00 2001
From: Andrzej Warzynski <andrzej.warzynski at arm.com>
Date: Sat, 21 Jun 2025 15:09:13 +0100
Subject: [PATCH 1/3] [mlir][linalg] Prevent hoisting of transfer pairs in the
 presence of aliases
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

This patch adds additional checks to the hoisting logic to prevent
hoisting of `vector.transfer_read`/`vector.transfer_write` pairs when
the underlying `memref` has users that introduce aliases via operations
implementing `ViewLikeOpInterface`.

Note: This may conservatively block some valid hoisting opportunities
and could impact performance. However, as demonstrated by the included
tests, the current behavior is too permissive and can lead to incorrect
transformations.

If this change prevents hoisting in cases that are provably safe, please
share a minimal repro — I’d be happy to explore ways to relax the check.
---
 .../Dialect/Linalg/Transforms/Hoisting.cpp    |  15 +-
 mlir/test/Dialect/Linalg/hoisting.mlir        | 133 ++++++++++++++++++
 2 files changed, 147 insertions(+), 1 deletion(-)

diff --git a/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp b/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp
index 707b63ff9335b..808925a934979 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp
@@ -303,7 +303,7 @@ void mlir::linalg::hoistRedundantVectorTransfers(Operation *root,
       //   1. indices, vector type and permutation map are the same (i.e., the
       //      transfer_read/transfer_write ops are matching),
       //   2. source operands for transfer.{read|write} do not originate from
-      //      Ops implementing ViewLikeOpInterface.
+      //      nor have users that are Ops implementing ViewLikeOpInterface.
       //   3. no other operations in the loop access the same memref except
       //      for transfer_read/transfer_write accessing statically disjoint
       //      slices.
@@ -312,14 +312,27 @@ void mlir::linalg::hoistRedundantVectorTransfers(Operation *root,
           transferRead.getPermutationMap() != transferWrite.getPermutationMap())
         return WalkResult::advance();
 
+      // Check 2. for xfer_read
       auto *source = transferRead.getBase().getDefiningOp();
       if (source && isa_and_nonnull<ViewLikeOpInterface>(source))
         return WalkResult::advance();
 
+      auto base = transferRead.getBase();
+      for (auto *user : base.getUsers())
+        if (isa_and_nonnull<ViewLikeOpInterface>(user))
+          return WalkResult::advance();
+
+      // Check 2. for xfer_wrire
       source = transferWrite.getBase().getDefiningOp();
       if (source && isa_and_nonnull<ViewLikeOpInterface>(source))
         return WalkResult::advance();
 
+      base = transferWrite.getBase();
+      for (auto *user : base.getUsers())
+        if (isa_and_nonnull<ViewLikeOpInterface>(user))
+          return WalkResult::advance();
+
+      // Check 1. + 3.
       // TODO: may want to memoize this information for performance but it
       // likely gets invalidated often.
       DominanceInfo dom(loop);
diff --git a/mlir/test/Dialect/Linalg/hoisting.mlir b/mlir/test/Dialect/Linalg/hoisting.mlir
index 318edca73cce1..fd5a3edfb743f 100644
--- a/mlir/test/Dialect/Linalg/hoisting.mlir
+++ b/mlir/test/Dialect/Linalg/hoisting.mlir
@@ -1,5 +1,138 @@
 // RUN: mlir-opt  -transform-interpreter -canonicalize --split-input-file --allow-unregistered-dialect %s | FileCheck %s
 
+// The most basic example - hoisting is safe.
+
+// CHECK-LABEL:   func.func @hoist_basic_vector_xfer_pair(
+// CHECK-SAME:      %[[MEM:[a-zA-Z0-9]+]]: memref<?x?xf32>,
+// CHECK-SAME:      %[[LB:[a-zA-Z0-9]+]]: index,
+// CHECK-SAME:      %[[UB:[a-zA-Z0-9]+]]: index,
+// CHECK-SAME:      %[[STEP:[a-zA-Z0-9]+]]: index) {
+func.func @hoist_basic_vector_xfer_pair(
+    %mem: memref<?x?xf32>, %lb : index, %ub : index, %step: index) {
+  %c0 = arith.constant 0 : index
+  %pad = arith.constant 0.0 : f32
+
+// CHECK:           %[[C0:.*]] = arith.constant 0 : index
+// CHECK:           %[[PAD:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK:           %[[READ:.*]] = vector.transfer_read %[[MEM]][%[[C0]], %[[C0]]], %[[PAD]] : memref<?x?xf32>, vector<1xf32>
+// CHECK:           %[[SCF:.*]] = scf.for %[[I:.*]] = %[[LB]] to %[[UB]] step %[[STEP]] iter_args(%[[INIT:.*]] = %[[READ]]) -> (vector<1xf32>) {
+// CHECK:             %[[VAL_6:.*]] = "some_use"(%[[INIT]]) : (vector<1xf32>) -> vector<1xf32>
+// CHECK:             scf.yield %[[VAL_6]] : vector<1xf32>
+// CHECK:           }
+// CHECK:           vector.transfer_write %[[SCF]], %[[MEM]][%[[C0]], %[[C0]]] : vector<1xf32>, memref<?x?xf32>
+  scf.for %i = %lb to %ub step %step {
+      %r0 = vector.transfer_read %mem[%c0, %c0], %pad: memref<?x?xf32>, vector<1xf32>
+      %u0 = "some_use"(%r0) : (vector<1xf32>) -> vector<1xf32>
+      vector.transfer_write %u0, %mem[%c0, %c0] : vector<1xf32>, memref<?x?xf32>
+  }
+  return
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["func.func"]} in %arg1
+      : (!transform.any_op) -> !transform.any_op
+    transform.structured.hoist_redundant_vector_transfers %0
+      : (!transform.any_op) -> !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+// Similar to the example above, but hoisting is no longer safe. That's due to
+// an extra xfer_write inside the loop.
+
+// CHECK-LABEL:   func.func @negative_hoist_basic_vector_xfer_pair_extra_write(
+// CHECK-SAME:      %[[MEM:[a-zA-Z0-9]+]]: memref<?x?xf32>,
+// CHECK-SAME:      %[[LB:[a-zA-Z0-9]+]]: index,
+// CHECK-SAME:      %[[UB:[a-zA-Z0-9]+]]: index,
+// CHECK-SAME:      %[[STEP:[a-zA-Z0-9]+]]: index,
+// CHECK-SAME:      %[[IN:[a-zA-Z0-9]+]]: vector<1xf32>) {
+func.func @negative_hoist_basic_vector_xfer_pair_extra_write(
+    %mem: memref<?x?xf32>, %lb : index, %ub : index, %step: index, %in: vector<1xf32>) {
+  %c0 = arith.constant 0 : index
+  %pad = arith.constant 0.0 : f32
+
+// CHECK:           %[[C0:.*]] = arith.constant 0 : index
+// CHECK:           %[[PAD:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK:           scf.for %[[I:.*]] = %[[LB]] to %[[UB]] step %[[STEP]] {
+// CHECK:             vector.transfer_write %[[IN]], %[[MEM]][%[[C0]], %[[C0]]] : vector<1xf32>, memref<?x?xf32>
+// CHECK:             %[[READ:.*]] = vector.transfer_read %[[MEM]][%[[C0]], %[[C0]]], %[[PAD]] : memref<?x?xf32>, vector<1xf32>
+// CHECK:             %[[USE:.*]] = "some_use"(%[[READ]]) : (vector<1xf32>) -> vector<1xf32>
+// CHECK:             vector.transfer_write %[[USE]], %[[MEM]][%[[C0]], %[[C0]]] : vector<1xf32>, memref<?x?xf32>
+// CHECK:           }
+
+  scf.for %i = %lb to %ub step %step {
+      vector.transfer_write %in, %mem[%c0, %c0] : vector<1xf32>, memref<?x?xf32>
+
+      %r0 = vector.transfer_read %mem[%c0, %c0], %pad: memref<?x?xf32>, vector<1xf32>
+      %u0 = "some_use"(%r0) : (vector<1xf32>) -> vector<1xf32>
+      vector.transfer_write %u0, %mem[%c0, %c0] : vector<1xf32>, memref<?x?xf32>
+  }
+  return
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["func.func"]} in %arg1
+      : (!transform.any_op) -> !transform.any_op
+    transform.structured.hoist_redundant_vector_transfers %0
+      : (!transform.any_op) -> !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+// Similar to the example above, but hoisting is no longer safe. That's due to
+// an extra xfer_write into _an alias_ of the %mem Op that is used by the
+// original xfer pair.
+
+// CHECK-LABEL:   func.func @negative_hoist_basic_vector_xfer_pair_extra_write_into_alias(
+// CHECK-SAME:      %[[MEM:[a-zA-Z0-9]+]]: memref<?x?xf32>,
+// CHECK-SAME:      %[[LB:[a-zA-Z0-9]+]]: index,
+// CHECK-SAME:      %[[UB:[a-zA-Z0-9]+]]: index,
+// CHECK-SAME:      %[[STEP:[a-zA-Z0-9]+]]: index,
+// CHECK-SAME:      %[[IN:[a-zA-Z0-9]+]]: vector<1xf32>) {
+func.func @negative_hoist_basic_vector_xfer_pair_extra_write_into_alias(
+    %mem: memref<?x?xf32>, %lb : index, %ub : index, %step: index, %in: vector<1xf32>) {
+  %c0 = arith.constant 0 : index
+  %pad = arith.constant 0.0 : f32
+
+// CHECK:           %[[C0:.*]] = arith.constant 0 : index
+// CHECK:           %[[PAD:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK:           %[[SV:.*]] = memref.subview %[[MEM]][0, 0] [1, 1] [1, 1] : memref<?x?xf32> to memref<1x1xf32, strided<[?, 1]>>
+// CHECK:           scf.for %[[I:.*]] = %[[LB]] to %[[UB]] step %[[STEP]] {
+// CHECK:             vector.transfer_write %[[IN]], %[[SV]][%[[C0]], %[[C0]]] {{.*}} : vector<1xf32>, memref<1x1xf32, strided<[?, 1]>>
+// CHECK:             %[[READ:.*]] = vector.transfer_read %[[MEM]][%[[C0]], %[[C0]]], %[[PAD]] : memref<?x?xf32>, vector<1xf32>
+// CHECK:             %[[USE:.*]] = "some_use"(%[[READ]]) : (vector<1xf32>) -> vector<1xf32>
+// CHECK:             vector.transfer_write %[[USE]], %[[MEM]][%[[C0]], %[[C0]]] : vector<1xf32>, memref<?x?xf32>
+// CHECK:           }
+
+  %sv = memref.subview %mem[0, 0][1, 1][1, 1] : memref<?x?xf32> to memref<1x1xf32, strided<[?, 1]>>
+  scf.for %i = %lb to %ub step %step {
+      vector.transfer_write %in, %sv[%c0, %c0] : vector<1xf32>, memref<1x1xf32, strided<[?, 1]>>
+
+      %r0 = vector.transfer_read %mem[%c0, %c0], %pad: memref<?x?xf32>, vector<1xf32>
+      %u0 = "some_use"(%r0) : (vector<1xf32>) -> vector<1xf32>
+      vector.transfer_write %u0, %mem[%c0, %c0] : vector<1xf32>, memref<?x?xf32>
+  }
+  return
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["func.func"]} in %arg1
+      : (!transform.any_op) -> !transform.any_op
+    transform.structured.hoist_redundant_vector_transfers %0
+      : (!transform.any_op) -> !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
 // CHECK-LABEL: func @hoist_vector_transfer_pairs(
 //  CHECK-SAME:   %[[MEMREF0:[a-zA-Z0-9]*]]: memref<?x?xf32>,
 //  CHECK-SAME:   %[[MEMREF1:[a-zA-Z0-9]*]]: memref<?x?xf32>,

>From 36bbbdc464df57e43d15f5046c35a799e705eb0f Mon Sep 17 00:00:00 2001
From: Andrzej Warzynski <andrzej.warzynski at arm.com>
Date: Wed, 25 Jun 2025 10:20:22 +0100
Subject: [PATCH 2/3] fixup! [mlir][linalg] Prevent hoisting of transfer pairs
 in the presence of aliases

1. Relax the conditions in the case of `memref.assume_alignment`. This
   unblocks hoisting for examples like the one reported in
   https://github.com/llvm/llvm-project/issues/144825 (see the link in
   the top post: "detailed example pls refer to example").
2. When checking the source operand for the xfer Ops, we only need to
   look at either xfer_write or xfer_read (we already know that the
   source is identical). The corresponding logic has been simplified.
---
 .../Dialect/Linalg/Transforms/Hoisting.cpp    | 40 ++++++++++-------
 mlir/test/Dialect/Linalg/hoisting.mlir        | 45 +++++++++++++++++++
 2 files changed, 70 insertions(+), 15 deletions(-)

diff --git a/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp b/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp
index 808925a934979..1a5cb322074aa 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp
@@ -307,32 +307,41 @@ void mlir::linalg::hoistRedundantVectorTransfers(Operation *root,
       //   3. no other operations in the loop access the same memref except
       //      for transfer_read/transfer_write accessing statically disjoint
       //      slices.
+
+      // Check 1.
       if (transferRead.getIndices() != transferWrite.getIndices() ||
           transferRead.getVectorType() != transferWrite.getVectorType() ||
           transferRead.getPermutationMap() != transferWrite.getPermutationMap())
         return WalkResult::advance();
 
-      // Check 2. for xfer_read
-      auto *source = transferRead.getBase().getDefiningOp();
-      if (source && isa_and_nonnull<ViewLikeOpInterface>(source))
-        return WalkResult::advance();
-
+      // Check 2. Note, since both xfer Ops share the source, we only need to
+      // look at one of them.
       auto base = transferRead.getBase();
-      for (auto *user : base.getUsers())
-        if (isa_and_nonnull<ViewLikeOpInterface>(user))
+      auto *source = base.getDefiningOp();
+      if (source) {
+        // NOTE: We treat `memref.assume_alignment` as a special case:
+        //  1. If it has exactly two uses then these have to be the xfer Ops
+        //    being looked at.
+        //  2. Otherwise, there are other users that we should take into
+        //    account
+        // In the case of 1., it is safe to look past AssumeAlignmentOp,
+        // i.e. at the defining Op of the input MemRef, provided that:
+        //  * the original MemRef has only one use (i.e.
+        //  `memref.assume_alignment`)
+        if (auto assume = dyn_cast<memref::AssumeAlignmentOp>(source)) {
+          Value memPreAlignment = assume.getMemref();
+          if (base.hasNUses(2) && memPreAlignment.hasOneUse())
+            source = memPreAlignment.getDefiningOp();
+        }
+        if (isa_and_nonnull<ViewLikeOpInterface>(source))
           return WalkResult::advance();
+      }
 
-      // Check 2. for xfer_wrire
-      source = transferWrite.getBase().getDefiningOp();
-      if (source && isa_and_nonnull<ViewLikeOpInterface>(source))
-        return WalkResult::advance();
-
-      base = transferWrite.getBase();
       for (auto *user : base.getUsers())
         if (isa_and_nonnull<ViewLikeOpInterface>(user))
           return WalkResult::advance();
 
-      // Check 1. + 3.
+      // Check 3.
       // TODO: may want to memoize this information for performance but it
       // likely gets invalidated often.
       DominanceInfo dom(loop);
@@ -371,7 +380,8 @@ void mlir::linalg::hoistRedundantVectorTransfers(Operation *root,
       // Hoist write after.
       transferWrite->moveAfter(loop);
 
-      // Rewrite `loop` with new yields by cloning and erase the original loop.
+      // Rewrite `loop` with new yields by cloning and erase the original
+      // loop.
       IRRewriter rewriter(transferRead.getContext());
       NewYieldValuesFn yieldFn = [&](OpBuilder &b, Location loc,
                                      ArrayRef<BlockArgument> newBBArgs) {
diff --git a/mlir/test/Dialect/Linalg/hoisting.mlir b/mlir/test/Dialect/Linalg/hoisting.mlir
index fd5a3edfb743f..7c2cd8cc0c150 100644
--- a/mlir/test/Dialect/Linalg/hoisting.mlir
+++ b/mlir/test/Dialect/Linalg/hoisting.mlir
@@ -133,6 +133,51 @@ module attributes {transform.with_named_sequence} {
 
 // -----
 
+// Similar to the example above, but the memory access is done via
+// memref.assume_alignment. Hoisting is safe as the only users of the
+// "alignment" Op are the xfer Ops within the loop that we want to hoist.
+
+// CHECK-LABEL:   func.func @hoist_basic_vector_xfer_pair_with_assume_align(
+// CHECK-SAME:      %[[MEM:[a-zA-Z0-9]+]]: memref<?x?xf32>,
+// CHECK-SAME:      %[[LB:[a-zA-Z0-9]+]]: index,
+// CHECK-SAME:      %[[UB:[a-zA-Z0-9]+]]: index,
+// CHECK-SAME:      %[[STEP:[a-zA-Z0-9]+]]: index,
+// CHECK-SAME:      %[[IN:[a-zA-Z0-9]+]]: vector<1xf32>) {
+func.func @hoist_basic_vector_xfer_pair_with_assume_align(
+    %mem: memref<?x?xf32>, %lb : index, %ub : index, %step: index, %in: vector<1xf32>) {
+  %c0 = arith.constant 0 : index
+  %pad = arith.constant 0.0 : f32
+
+// CHECK:           %[[C0:.*]] = arith.constant 0 : index
+// CHECK:           %[[PAD:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK:           %[[AA:.*]] = memref.assume_alignment %[[MEM]], 4 : memref<?x?xf32>
+// CHECK:           %[[READ:.*]] = vector.transfer_read %[[AA]][%[[C0]], %[[C0]]], %[[PAD]] : memref<?x?xf32>, vector<1xf32>
+// CHECK:           %[[SCF:.*]] = scf.for %[[I:.*]] = %[[LB]] to %[[UB]] step %[[STEP]]  iter_args(%[[INIT:.*]] = %[[READ]]) -> (vector<1xf32>) {
+// CHECK:             %[[USE:.*]] = "some_use"(%[[INIT]]) : (vector<1xf32>) -> vector<1xf32>
+// CHECK:           }
+// CHECK:           vector.transfer_write %[[SCF]], %[[AA]][%[[C0]], %[[C0]]] : vector<1xf32>, memref<?x?xf32>
+
+  %aa = memref.assume_alignment %mem, 4 : memref<?x?xf32>
+  scf.for %i = %lb to %ub step %step {
+      %r0 = vector.transfer_read %aa[%c0, %c0], %pad: memref<?x?xf32>, vector<1xf32>
+      %u0 = "some_use"(%r0) : (vector<1xf32>) -> vector<1xf32>
+      vector.transfer_write %u0, %aa[%c0, %c0] : vector<1xf32>, memref<?x?xf32>
+  }
+  return
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["func.func"]} in %arg1
+      : (!transform.any_op) -> !transform.any_op
+    transform.structured.hoist_redundant_vector_transfers %0
+      : (!transform.any_op) -> !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
 // CHECK-LABEL: func @hoist_vector_transfer_pairs(
 //  CHECK-SAME:   %[[MEMREF0:[a-zA-Z0-9]*]]: memref<?x?xf32>,
 //  CHECK-SAME:   %[[MEMREF1:[a-zA-Z0-9]*]]: memref<?x?xf32>,

>From a671f3d45d093806e335da06df4af5f17fdeca93 Mon Sep 17 00:00:00 2001
From: Andrzej Warzynski <andrzej.warzynski at arm.com>
Date: Wed, 25 Jun 2025 11:00:23 +0100
Subject: [PATCH 3/3] fixup! fixup! [mlir][linalg] Prevent hoisting of transfer
 pairs in the presence of aliases

Make sure only uses inside the loop are counted
---
 mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp | 7 ++++++-
 1 file changed, 6 insertions(+), 1 deletion(-)

diff --git a/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp b/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp
index 1a5cb322074aa..4b1fdd4a4a3bc 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp
@@ -330,7 +330,12 @@ void mlir::linalg::hoistRedundantVectorTransfers(Operation *root,
         //  `memref.assume_alignment`)
         if (auto assume = dyn_cast<memref::AssumeAlignmentOp>(source)) {
           Value memPreAlignment = assume.getMemref();
-          if (base.hasNUses(2) && memPreAlignment.hasOneUse())
+          auto numInLoopUses =
+              llvm::count_if(base.getUses(), [&loop](OpOperand &use) {
+                return loop->isAncestor(use.getOwner());
+              });
+
+          if (numInLoopUses && memPreAlignment.hasOneUse())
             source = memPreAlignment.getDefiningOp();
         }
         if (isa_and_nonnull<ViewLikeOpInterface>(source))



More information about the Mlir-commits mailing list