[Mlir-commits] [mlir] [mlir][linalg] Prevent hoisting of transfer pairs in the presence of aliases (PR #145235)
Andrzej Warzyński
llvmlistbot at llvm.org
Sun Jun 22 06:54:01 PDT 2025
https://github.com/banach-space created https://github.com/llvm/llvm-project/pull/145235
This patch adds additional checks to the hoisting logic to prevent
hoisting of `vector.transfer_read`/`vector.transfer_write` pairs when
the underlying `memref` has users that introduce aliases via operations
implementing `ViewLikeOpInterface`.
Note: This may conservatively block some valid hoisting opportunities
and could impact performance. However, as demonstrated by the included
tests, the current behavior is too permissive and can lead to incorrect
transformations.
If this change prevents hoisting in cases that are provably safe, please
share a minimal repro — I’d be happy to explore ways to relax the check.
>From f7218d71dd69f63b14e6fb5d3da06228982b7423 Mon Sep 17 00:00:00 2001
From: Andrzej Warzynski <andrzej.warzynski at arm.com>
Date: Sat, 21 Jun 2025 15:09:13 +0100
Subject: [PATCH] [mlir][linalg] Prevent hoisting of transfer pairs in the
presence of aliases
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
This patch adds additional checks to the hoisting logic to prevent
hoisting of `vector.transfer_read`/`vector.transfer_write` pairs when
the underlying `memref` has users that introduce aliases via operations
implementing `ViewLikeOpInterface`.
Note: This may conservatively block some valid hoisting opportunities
and could impact performance. However, as demonstrated by the included
tests, the current behavior is too permissive and can lead to incorrect
transformations.
If this change prevents hoisting in cases that are provably safe, please
share a minimal repro — I’d be happy to explore ways to relax the check.
---
.../Dialect/Linalg/Transforms/Hoisting.cpp | 15 +-
mlir/test/Dialect/Linalg/hoisting.mlir | 133 ++++++++++++++++++
2 files changed, 147 insertions(+), 1 deletion(-)
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp b/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp
index 707b63ff9335b..808925a934979 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp
@@ -303,7 +303,7 @@ void mlir::linalg::hoistRedundantVectorTransfers(Operation *root,
// 1. indices, vector type and permutation map are the same (i.e., the
// transfer_read/transfer_write ops are matching),
// 2. source operands for transfer.{read|write} do not originate from
- // Ops implementing ViewLikeOpInterface.
+ // Ops implementing ViewLikeOpInterface, nor have users that are such Ops.
// 3. no other operations in the loop access the same memref except
// for transfer_read/transfer_write accessing statically disjoint
// slices.
@@ -312,14 +312,27 @@ void mlir::linalg::hoistRedundantVectorTransfers(Operation *root,
transferRead.getPermutationMap() != transferWrite.getPermutationMap())
return WalkResult::advance();
+ // Check 2. for xfer_read
auto *source = transferRead.getBase().getDefiningOp();
if (source && isa_and_nonnull<ViewLikeOpInterface>(source))
return WalkResult::advance();
+ auto base = transferRead.getBase();
+ for (auto *user : base.getUsers())
+ if (isa_and_nonnull<ViewLikeOpInterface>(user))
+ return WalkResult::advance();
+
+ // Check 2. for xfer_write
source = transferWrite.getBase().getDefiningOp();
if (source && isa_and_nonnull<ViewLikeOpInterface>(source))
return WalkResult::advance();
+ base = transferWrite.getBase();
+ for (auto *user : base.getUsers())
+ if (isa_and_nonnull<ViewLikeOpInterface>(user))
+ return WalkResult::advance();
+
+ // Check 1. + 3.
// TODO: may want to memoize this information for performance but it
// likely gets invalidated often.
DominanceInfo dom(loop);
diff --git a/mlir/test/Dialect/Linalg/hoisting.mlir b/mlir/test/Dialect/Linalg/hoisting.mlir
index 318edca73cce1..fd5a3edfb743f 100644
--- a/mlir/test/Dialect/Linalg/hoisting.mlir
+++ b/mlir/test/Dialect/Linalg/hoisting.mlir
@@ -1,5 +1,138 @@
// RUN: mlir-opt -transform-interpreter -canonicalize --split-input-file --allow-unregistered-dialect %s | FileCheck %s
+// The most basic example - hoisting is safe.
+
+// CHECK-LABEL: func.func @hoist_basic_vector_xfer_pair(
+// CHECK-SAME: %[[MEM:[a-zA-Z0-9]+]]: memref<?x?xf32>,
+// CHECK-SAME: %[[LB:[a-zA-Z0-9]+]]: index,
+// CHECK-SAME: %[[UB:[a-zA-Z0-9]+]]: index,
+// CHECK-SAME: %[[STEP:[a-zA-Z0-9]+]]: index) {
+func.func @hoist_basic_vector_xfer_pair(
+ %mem: memref<?x?xf32>, %lb : index, %ub : index, %step: index) {
+ %c0 = arith.constant 0 : index
+ %pad = arith.constant 0.0 : f32
+
+// CHECK: %[[C0:.*]] = arith.constant 0 : index
+// CHECK: %[[PAD:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK: %[[READ:.*]] = vector.transfer_read %[[MEM]][%[[C0]], %[[C0]]], %[[PAD]] : memref<?x?xf32>, vector<1xf32>
+// CHECK: %[[SCF:.*]] = scf.for %[[I:.*]] = %[[LB]] to %[[UB]] step %[[STEP]] iter_args(%[[INIT:.*]] = %[[READ]]) -> (vector<1xf32>) {
+// CHECK: %[[VAL_6:.*]] = "some_use"(%[[INIT]]) : (vector<1xf32>) -> vector<1xf32>
+// CHECK: scf.yield %[[VAL_6]] : vector<1xf32>
+// CHECK: }
+// CHECK: vector.transfer_write %[[SCF]], %[[MEM]][%[[C0]], %[[C0]]] : vector<1xf32>, memref<?x?xf32>
+ scf.for %i = %lb to %ub step %step {
+ %r0 = vector.transfer_read %mem[%c0, %c0], %pad: memref<?x?xf32>, vector<1xf32>
+ %u0 = "some_use"(%r0) : (vector<1xf32>) -> vector<1xf32>
+ vector.transfer_write %u0, %mem[%c0, %c0] : vector<1xf32>, memref<?x?xf32>
+ }
+ return
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["func.func"]} in %arg1
+ : (!transform.any_op) -> !transform.any_op
+ transform.structured.hoist_redundant_vector_transfers %0
+ : (!transform.any_op) -> !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
+// Similar to the example above, but hoisting is no longer safe due to an
+// extra xfer_write to the same location inside the loop.
+
+// CHECK-LABEL: func.func @negative_hoist_basic_vector_xfer_pair_extra_write(
+// CHECK-SAME: %[[MEM:[a-zA-Z0-9]+]]: memref<?x?xf32>,
+// CHECK-SAME: %[[LB:[a-zA-Z0-9]+]]: index,
+// CHECK-SAME: %[[UB:[a-zA-Z0-9]+]]: index,
+// CHECK-SAME: %[[STEP:[a-zA-Z0-9]+]]: index,
+// CHECK-SAME: %[[IN:[a-zA-Z0-9]+]]: vector<1xf32>) {
+func.func @negative_hoist_basic_vector_xfer_pair_extra_write(
+ %mem: memref<?x?xf32>, %lb : index, %ub : index, %step: index, %in: vector<1xf32>) {
+ %c0 = arith.constant 0 : index
+ %pad = arith.constant 0.0 : f32
+
+// CHECK: %[[C0:.*]] = arith.constant 0 : index
+// CHECK: %[[PAD:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK: scf.for %[[I:.*]] = %[[LB]] to %[[UB]] step %[[STEP]] {
+// CHECK: vector.transfer_write %[[IN]], %[[MEM]][%[[C0]], %[[C0]]] : vector<1xf32>, memref<?x?xf32>
+// CHECK: %[[READ:.*]] = vector.transfer_read %[[MEM]][%[[C0]], %[[C0]]], %[[PAD]] : memref<?x?xf32>, vector<1xf32>
+// CHECK: %[[USE:.*]] = "some_use"(%[[READ]]) : (vector<1xf32>) -> vector<1xf32>
+// CHECK: vector.transfer_write %[[USE]], %[[MEM]][%[[C0]], %[[C0]]] : vector<1xf32>, memref<?x?xf32>
+// CHECK: }
+
+ scf.for %i = %lb to %ub step %step {
+ vector.transfer_write %in, %mem[%c0, %c0] : vector<1xf32>, memref<?x?xf32>
+
+ %r0 = vector.transfer_read %mem[%c0, %c0], %pad: memref<?x?xf32>, vector<1xf32>
+ %u0 = "some_use"(%r0) : (vector<1xf32>) -> vector<1xf32>
+ vector.transfer_write %u0, %mem[%c0, %c0] : vector<1xf32>, memref<?x?xf32>
+ }
+ return
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["func.func"]} in %arg1
+ : (!transform.any_op) -> !transform.any_op
+ transform.structured.hoist_redundant_vector_transfers %0
+ : (!transform.any_op) -> !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
+// Similar to the example above, but hoisting is no longer safe due to an
+// extra xfer_write into _an alias_ of %mem (the memref accessed by the
+// original xfer pair).
+
+// CHECK-LABEL: func.func @negative_hoist_basic_vector_xfer_pair_extra_write_into_alias(
+// CHECK-SAME: %[[MEM:[a-zA-Z0-9]+]]: memref<?x?xf32>,
+// CHECK-SAME: %[[LB:[a-zA-Z0-9]+]]: index,
+// CHECK-SAME: %[[UB:[a-zA-Z0-9]+]]: index,
+// CHECK-SAME: %[[STEP:[a-zA-Z0-9]+]]: index,
+// CHECK-SAME: %[[IN:[a-zA-Z0-9]+]]: vector<1xf32>) {
+func.func @negative_hoist_basic_vector_xfer_pair_extra_write_into_alias(
+ %mem: memref<?x?xf32>, %lb : index, %ub : index, %step: index, %in: vector<1xf32>) {
+ %c0 = arith.constant 0 : index
+ %pad = arith.constant 0.0 : f32
+
+// CHECK: %[[C0:.*]] = arith.constant 0 : index
+// CHECK: %[[PAD:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK: %[[SV:.*]] = memref.subview %[[MEM]][0, 0] [1, 1] [1, 1] : memref<?x?xf32> to memref<1x1xf32, strided<[?, 1]>>
+// CHECK: scf.for %[[I:.*]] = %[[LB]] to %[[UB]] step %[[STEP]] {
+// CHECK: vector.transfer_write %[[IN]], %[[SV]][%[[C0]], %[[C0]]] {{.*}} : vector<1xf32>, memref<1x1xf32, strided<[?, 1]>>
+// CHECK: %[[READ:.*]] = vector.transfer_read %[[MEM]][%[[C0]], %[[C0]]], %[[PAD]] : memref<?x?xf32>, vector<1xf32>
+// CHECK: %[[USE:.*]] = "some_use"(%[[READ]]) : (vector<1xf32>) -> vector<1xf32>
+// CHECK: vector.transfer_write %[[USE]], %[[MEM]][%[[C0]], %[[C0]]] : vector<1xf32>, memref<?x?xf32>
+// CHECK: }
+
+ %sv = memref.subview %mem[0, 0][1, 1][1, 1] : memref<?x?xf32> to memref<1x1xf32, strided<[?, 1]>>
+ scf.for %i = %lb to %ub step %step {
+ vector.transfer_write %in, %sv[%c0, %c0] : vector<1xf32>, memref<1x1xf32, strided<[?, 1]>>
+
+ %r0 = vector.transfer_read %mem[%c0, %c0], %pad: memref<?x?xf32>, vector<1xf32>
+ %u0 = "some_use"(%r0) : (vector<1xf32>) -> vector<1xf32>
+ vector.transfer_write %u0, %mem[%c0, %c0] : vector<1xf32>, memref<?x?xf32>
+ }
+ return
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["func.func"]} in %arg1
+ : (!transform.any_op) -> !transform.any_op
+ transform.structured.hoist_redundant_vector_transfers %0
+ : (!transform.any_op) -> !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
// CHECK-LABEL: func @hoist_vector_transfer_pairs(
// CHECK-SAME: %[[MEMREF0:[a-zA-Z0-9]*]]: memref<?x?xf32>,
// CHECK-SAME: %[[MEMREF1:[a-zA-Z0-9]*]]: memref<?x?xf32>,
More information about the Mlir-commits
mailing list