[Mlir-commits] [mlir] [mlir][linalg] Prevent hoisting of transfer pairs in the presence of aliases (PR #145235)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Sun Jun 22 06:54:34 PDT 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

Author: Andrzej Warzyński (banach-space)

<details>
<summary>Changes</summary>

This patch adds additional checks to the hoisting logic to prevent
hoisting of `vector.transfer_read`/`vector.transfer_write` pairs when
the underlying `memref` has users that introduce aliases via operations
implementing `ViewLikeOpInterface`.

Note: This may conservatively block some valid hoisting opportunities
and could impact performance. However, as demonstrated by the included
tests, the current behavior is too permissive and can lead to incorrect
transformations.

If this change prevents hoisting in cases that are provably safe, please
share a minimal repro — I’d be happy to explore ways to relax the check.


---
Full diff: https://github.com/llvm/llvm-project/pull/145235.diff


2 Files Affected:

- (modified) mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp (+14-1) 
- (modified) mlir/test/Dialect/Linalg/hoisting.mlir (+133) 


``````````diff
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp b/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp
index 707b63ff9335b..808925a934979 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp
@@ -303,7 +303,7 @@ void mlir::linalg::hoistRedundantVectorTransfers(Operation *root,
       //   1. indices, vector type and permutation map are the same (i.e., the
       //      transfer_read/transfer_write ops are matching),
       //   2. source operands for transfer.{read|write} do not originate from
-      //      Ops implementing ViewLikeOpInterface.
+      //      nor have users that are Ops implementing ViewLikeOpInterface.
       //   3. no other operations in the loop access the same memref except
       //      for transfer_read/transfer_write accessing statically disjoint
       //      slices.
@@ -312,14 +312,27 @@ void mlir::linalg::hoistRedundantVectorTransfers(Operation *root,
           transferRead.getPermutationMap() != transferWrite.getPermutationMap())
         return WalkResult::advance();
 
+      // Check 2. for xfer_read
       auto *source = transferRead.getBase().getDefiningOp();
       if (source && isa_and_nonnull<ViewLikeOpInterface>(source))
         return WalkResult::advance();
 
+      auto base = transferRead.getBase();
+      for (auto *user : base.getUsers())
+        if (isa_and_nonnull<ViewLikeOpInterface>(user))
+          return WalkResult::advance();
+
+      // Check 2. for xfer_write
       source = transferWrite.getBase().getDefiningOp();
       if (source && isa_and_nonnull<ViewLikeOpInterface>(source))
         return WalkResult::advance();
 
+      base = transferWrite.getBase();
+      for (auto *user : base.getUsers())
+        if (isa_and_nonnull<ViewLikeOpInterface>(user))
+          return WalkResult::advance();
+
+      // Check 1. + 3.
       // TODO: may want to memoize this information for performance but it
       // likely gets invalidated often.
       DominanceInfo dom(loop);
diff --git a/mlir/test/Dialect/Linalg/hoisting.mlir b/mlir/test/Dialect/Linalg/hoisting.mlir
index 318edca73cce1..fd5a3edfb743f 100644
--- a/mlir/test/Dialect/Linalg/hoisting.mlir
+++ b/mlir/test/Dialect/Linalg/hoisting.mlir
@@ -1,5 +1,138 @@
 // RUN: mlir-opt  -transform-interpreter -canonicalize --split-input-file --allow-unregistered-dialect %s | FileCheck %s
 
+// The most basic example - hoisting is safe.
+
+// CHECK-LABEL:   func.func @hoist_basic_vector_xfer_pair(
+// CHECK-SAME:      %[[MEM:[a-zA-Z0-9]+]]: memref<?x?xf32>,
+// CHECK-SAME:      %[[LB:[a-zA-Z0-9]+]]: index,
+// CHECK-SAME:      %[[UB:[a-zA-Z0-9]+]]: index,
+// CHECK-SAME:      %[[STEP:[a-zA-Z0-9]+]]: index) {
+func.func @hoist_basic_vector_xfer_pair(
+    %mem: memref<?x?xf32>, %lb : index, %ub : index, %step: index) {
+  %c0 = arith.constant 0 : index
+  %pad = arith.constant 0.0 : f32
+
+// CHECK:           %[[C0:.*]] = arith.constant 0 : index
+// CHECK:           %[[PAD:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK:           %[[READ:.*]] = vector.transfer_read %[[MEM]][%[[C0]], %[[C0]]], %[[PAD]] : memref<?x?xf32>, vector<1xf32>
+// CHECK:           %[[SCF:.*]] = scf.for %[[I:.*]] = %[[LB]] to %[[UB]] step %[[STEP]] iter_args(%[[INIT:.*]] = %[[READ]]) -> (vector<1xf32>) {
+// CHECK:             %[[VAL_6:.*]] = "some_use"(%[[INIT]]) : (vector<1xf32>) -> vector<1xf32>
+// CHECK:             scf.yield %[[VAL_6]] : vector<1xf32>
+// CHECK:           }
+// CHECK:           vector.transfer_write %[[SCF]], %[[MEM]][%[[C0]], %[[C0]]] : vector<1xf32>, memref<?x?xf32>
+  scf.for %i = %lb to %ub step %step {
+      %r0 = vector.transfer_read %mem[%c0, %c0], %pad: memref<?x?xf32>, vector<1xf32>
+      %u0 = "some_use"(%r0) : (vector<1xf32>) -> vector<1xf32>
+      vector.transfer_write %u0, %mem[%c0, %c0] : vector<1xf32>, memref<?x?xf32>
+  }
+  return
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["func.func"]} in %arg1
+      : (!transform.any_op) -> !transform.any_op
+    transform.structured.hoist_redundant_vector_transfers %0
+      : (!transform.any_op) -> !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+// Similar to the example above, but hoisting is no longer safe. That's due to
+// an extra xfer_write inside the loop.
+
+// CHECK-LABEL:   func.func @negative_hoist_basic_vector_xfer_pair_extra_write(
+// CHECK-SAME:      %[[MEM:[a-zA-Z0-9]+]]: memref<?x?xf32>,
+// CHECK-SAME:      %[[LB:[a-zA-Z0-9]+]]: index,
+// CHECK-SAME:      %[[UB:[a-zA-Z0-9]+]]: index,
+// CHECK-SAME:      %[[STEP:[a-zA-Z0-9]+]]: index,
+// CHECK-SAME:      %[[IN:[a-zA-Z0-9]+]]: vector<1xf32>) {
+func.func @negative_hoist_basic_vector_xfer_pair_extra_write(
+    %mem: memref<?x?xf32>, %lb : index, %ub : index, %step: index, %in: vector<1xf32>) {
+  %c0 = arith.constant 0 : index
+  %pad = arith.constant 0.0 : f32
+
+// CHECK:           %[[C0:.*]] = arith.constant 0 : index
+// CHECK:           %[[PAD:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK:           scf.for %[[I:.*]] = %[[LB]] to %[[UB]] step %[[STEP]] {
+// CHECK:             vector.transfer_write %[[IN]], %[[MEM]][%[[C0]], %[[C0]]] : vector<1xf32>, memref<?x?xf32>
+// CHECK:             %[[READ:.*]] = vector.transfer_read %[[MEM]][%[[C0]], %[[C0]]], %[[PAD]] : memref<?x?xf32>, vector<1xf32>
+// CHECK:             %[[USE:.*]] = "some_use"(%[[READ]]) : (vector<1xf32>) -> vector<1xf32>
+// CHECK:             vector.transfer_write %[[USE]], %[[MEM]][%[[C0]], %[[C0]]] : vector<1xf32>, memref<?x?xf32>
+// CHECK:           }
+
+  scf.for %i = %lb to %ub step %step {
+      vector.transfer_write %in, %mem[%c0, %c0] : vector<1xf32>, memref<?x?xf32>
+
+      %r0 = vector.transfer_read %mem[%c0, %c0], %pad: memref<?x?xf32>, vector<1xf32>
+      %u0 = "some_use"(%r0) : (vector<1xf32>) -> vector<1xf32>
+      vector.transfer_write %u0, %mem[%c0, %c0] : vector<1xf32>, memref<?x?xf32>
+  }
+  return
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["func.func"]} in %arg1
+      : (!transform.any_op) -> !transform.any_op
+    transform.structured.hoist_redundant_vector_transfers %0
+      : (!transform.any_op) -> !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+// Similar to the example above, but hoisting is no longer safe. That's due to
+// an extra xfer_write into _an alias_ of the %mem Op that is used by the
+// original xfer pair.
+
+// CHECK-LABEL:   func.func @negative_hoist_basic_vector_xfer_pair_extra_write_into_alias(
+// CHECK-SAME:      %[[MEM:[a-zA-Z0-9]+]]: memref<?x?xf32>,
+// CHECK-SAME:      %[[LB:[a-zA-Z0-9]+]]: index,
+// CHECK-SAME:      %[[UB:[a-zA-Z0-9]+]]: index,
+// CHECK-SAME:      %[[STEP:[a-zA-Z0-9]+]]: index,
+// CHECK-SAME:      %[[IN:[a-zA-Z0-9]+]]: vector<1xf32>) {
+func.func @negative_hoist_basic_vector_xfer_pair_extra_write_into_alias(
+    %mem: memref<?x?xf32>, %lb : index, %ub : index, %step: index, %in: vector<1xf32>) {
+  %c0 = arith.constant 0 : index
+  %pad = arith.constant 0.0 : f32
+
+// CHECK:           %[[C0:.*]] = arith.constant 0 : index
+// CHECK:           %[[PAD:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK:           %[[SV:.*]] = memref.subview %[[MEM]][0, 0] [1, 1] [1, 1] : memref<?x?xf32> to memref<1x1xf32, strided<[?, 1]>>
+// CHECK:           scf.for %[[I:.*]] = %[[LB]] to %[[UB]] step %[[STEP]] {
+// CHECK:             vector.transfer_write %[[IN]], %[[SV]][%[[C0]], %[[C0]]] {{.*}} : vector<1xf32>, memref<1x1xf32, strided<[?, 1]>>
+// CHECK:             %[[READ:.*]] = vector.transfer_read %[[MEM]][%[[C0]], %[[C0]]], %[[PAD]] : memref<?x?xf32>, vector<1xf32>
+// CHECK:             %[[USE:.*]] = "some_use"(%[[READ]]) : (vector<1xf32>) -> vector<1xf32>
+// CHECK:             vector.transfer_write %[[USE]], %[[MEM]][%[[C0]], %[[C0]]] : vector<1xf32>, memref<?x?xf32>
+// CHECK:           }
+
+  %sv = memref.subview %mem[0, 0][1, 1][1, 1] : memref<?x?xf32> to memref<1x1xf32, strided<[?, 1]>>
+  scf.for %i = %lb to %ub step %step {
+      vector.transfer_write %in, %sv[%c0, %c0] : vector<1xf32>, memref<1x1xf32, strided<[?, 1]>>
+
+      %r0 = vector.transfer_read %mem[%c0, %c0], %pad: memref<?x?xf32>, vector<1xf32>
+      %u0 = "some_use"(%r0) : (vector<1xf32>) -> vector<1xf32>
+      vector.transfer_write %u0, %mem[%c0, %c0] : vector<1xf32>, memref<?x?xf32>
+  }
+  return
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["func.func"]} in %arg1
+      : (!transform.any_op) -> !transform.any_op
+    transform.structured.hoist_redundant_vector_transfers %0
+      : (!transform.any_op) -> !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
 // CHECK-LABEL: func @hoist_vector_transfer_pairs(
 //  CHECK-SAME:   %[[MEMREF0:[a-zA-Z0-9]*]]: memref<?x?xf32>,
 //  CHECK-SAME:   %[[MEMREF1:[a-zA-Z0-9]*]]: memref<?x?xf32>,

``````````

</details>


https://github.com/llvm/llvm-project/pull/145235


More information about the Mlir-commits mailing list