[Mlir-commits] [mlir] 541f33e - [mlir][linalg] Prevent hoisting of transfer pairs in the presence of aliases (#145235)
llvmlistbot at llvm.org
Fri Jun 27 05:18:19 PDT 2025
Author: Andrzej Warzyński
Date: 2025-06-27T13:18:15+01:00
New Revision: 541f33e0751d60b33e75efe0cd436396f27b91ca
URL: https://github.com/llvm/llvm-project/commit/541f33e0751d60b33e75efe0cd436396f27b91ca
DIFF: https://github.com/llvm/llvm-project/commit/541f33e0751d60b33e75efe0cd436396f27b91ca.diff
LOG: [mlir][linalg] Prevent hoisting of transfer pairs in the presence of aliases (#145235)
This patch adds additional checks to the hoisting logic to prevent hoisting of
`vector.transfer_read` / `vector.transfer_write` pairs when the underlying
memref has users that introduce aliases via operations implementing
`ViewLikeOpInterface`.
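For illustration, the kind of pattern that is now rejected looks like this
(condensed from one of the new tests below; the alias is introduced by
`memref.subview`, which implements `ViewLikeOpInterface`):

  %sv = memref.subview %mem[0, 0][1, 1][1, 1]
      : memref<?x?xf32> to memref<1x1xf32, strided<[?, 1]>>
  scf.for %i = %lb to %ub step %step {
    // Writing through the alias may clobber the data that the pair below
    // reads, so the read/write pair on %mem must stay inside the loop.
    vector.transfer_write %in, %sv[%c0, %c0]
        : vector<1xf32>, memref<1x1xf32, strided<[?, 1]>>
    %r0 = vector.transfer_read %mem[%c0, %c0], %pad
        : memref<?x?xf32>, vector<1xf32>
    vector.transfer_write %r0, %mem[%c0, %c0]
        : vector<1xf32>, memref<?x?xf32>
  }

Hoisting the pair on %mem would read the buffer only once, before any of the
per-iteration writes through %sv take effect.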
Note: This may conservatively block some valid hoisting opportunities and could
affect performance. However, as demonstrated by the included tests, the current
logic is too permissive and can lead to incorrect transformations.
If this change prevents hoisting in cases that are provably safe, please share
a minimal repro - I'm happy to explore ways to relax the check.
Special treatment is given to `memref.assume_alignment`, mainly to accommodate
recent updates in:
* https://github.com/llvm/llvm-project/pull/139521
Note that such special casing does not scale and should generally be avoided.
The current hoisting logic lacks robust alias analysis. Better support would
require more work, and the broader semantics of `memref.assume_alignment`
remain somewhat unclear; the op may eventually be replaced with
the "alignment" attribute added in:
* https://github.com/llvm/llvm-project/pull/144344
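For reference, the shape that the special case keeps hoistable looks like this
(condensed from the new hoist_basic_vector_xfer_pair_with_assume_align test
below); the result of `memref.assume_alignment` is used only by the xfer pair
inside the loop, and the original memref has no other users:

  %aa = memref.assume_alignment %mem, 4 : memref<?x?xf32>
  scf.for %i = %lb to %ub step %step {
    %r0 = vector.transfer_read %aa[%c0, %c0], %pad
        : memref<?x?xf32>, vector<1xf32>
    %u0 = "val_use"(%r0) : (vector<1xf32>) -> vector<1xf32>
    vector.transfer_write %u0, %aa[%c0, %c0]
        : vector<1xf32>, memref<?x?xf32>
  }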
Added:
Modified:
mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp
mlir/test/Dialect/Linalg/hoisting.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp b/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp
index 707b63ff9335b..d833e04d60264 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp
@@ -303,23 +303,51 @@ void mlir::linalg::hoistRedundantVectorTransfers(Operation *root,
// 1. indices, vector type and permutation map are the same (i.e., the
// transfer_read/transfer_write ops are matching),
// 2. source operands for transfer.{read|write} do not originate from
- // Ops implementing ViewLikeOpInterface.
+ // nor have users that are Ops implementing ViewLikeOpInterface.
// 3. no other operations in the loop access the same memref except
// for transfer_read/transfer_write accessing statically disjoint
// slices.
+
+ // Check 1.
if (transferRead.getIndices() != transferWrite.getIndices() ||
transferRead.getVectorType() != transferWrite.getVectorType() ||
transferRead.getPermutationMap() != transferWrite.getPermutationMap())
return WalkResult::advance();
- auto *source = transferRead.getBase().getDefiningOp();
- if (source && isa_and_nonnull<ViewLikeOpInterface>(source))
- return WalkResult::advance();
+ // Check 2. Note: since both xfer Ops share the source, we only need to
+ // look at one of them.
+ auto base = transferRead.getBase();
+ auto *source = base.getDefiningOp();
+ if (source) {
+ // NOTE: We treat `memref.assume_alignment` as a special case.
+ //
+ // The idea is that it is safe to look past AssumeAlignmentOp (i.e.
+ // to the MemRef _before_ alignment) iff:
+ // 1. It has exactly two uses (these have to be the xfer Ops
+ // being looked at).
+ // 2. The original MemRef has only one use (i.e.
+ // AssumeAlignmentOp).
+ //
+ // Relaxing these conditions will most likely require proper alias
+ // analysis.
+ if (auto assume = dyn_cast<memref::AssumeAlignmentOp>(source)) {
+ Value memPreAlignment = assume.getMemref();
+ auto numInLoopUses =
+ llvm::count_if(base.getUses(), [&loop](OpOperand &use) {
+ return loop->isAncestor(use.getOwner());
+ });
+
+ if (numInLoopUses && memPreAlignment.hasOneUse())
+ source = memPreAlignment.getDefiningOp();
+ }
+ if (isa_and_nonnull<ViewLikeOpInterface>(source))
+ return WalkResult::advance();
+ }
- source = transferWrite.getBase().getDefiningOp();
- if (source && isa_and_nonnull<ViewLikeOpInterface>(source))
+ if (llvm::any_of(base.getUsers(), llvm::IsaPred<ViewLikeOpInterface>))
return WalkResult::advance();
+ // Check 3.
// TODO: may want to memoize this information for performance but it
// likely gets invalidated often.
DominanceInfo dom(loop);
@@ -358,7 +386,8 @@ void mlir::linalg::hoistRedundantVectorTransfers(Operation *root,
// Hoist write after.
transferWrite->moveAfter(loop);
- // Rewrite `loop` with new yields by cloning and erase the original loop.
+ // Rewrite `loop` with new yields by cloning and erase the original
+ // loop.
IRRewriter rewriter(transferRead.getContext());
NewYieldValuesFn yieldFn = [&](OpBuilder &b, Location loc,
ArrayRef<BlockArgument> newBBArgs) {
diff --git a/mlir/test/Dialect/Linalg/hoisting.mlir b/mlir/test/Dialect/Linalg/hoisting.mlir
index 8be4e1b79c52c..aa0b97a4787fa 100644
--- a/mlir/test/Dialect/Linalg/hoisting.mlir
+++ b/mlir/test/Dialect/Linalg/hoisting.mlir
@@ -1,5 +1,234 @@
// RUN: mlir-opt -transform-interpreter -canonicalize --split-input-file --allow-unregistered-dialect %s | FileCheck %s
+///----------------------------------------------------------------------------------------
+/// Tests for vector.transfer_read + vector.transfer_write pairs
+///
+/// * Nested inside a single loop
+/// * Indices are constant
+///----------------------------------------------------------------------------------------
+
+// The most basic example - hoisting is safe.
+
+// CHECK-LABEL: func.func @hoist_basic_vector_xfer_pair(
+// CHECK-SAME: %[[MEM:[a-zA-Z0-9]+]]: memref<?x?xf32>,
+// CHECK-SAME: %[[LB:[a-zA-Z0-9]+]]: index,
+// CHECK-SAME: %[[UB:[a-zA-Z0-9]+]]: index,
+// CHECK-SAME: %[[STEP:[a-zA-Z0-9]+]]: index) {
+func.func @hoist_basic_vector_xfer_pair(
+ %mem: memref<?x?xf32>, %lb : index, %ub : index, %step: index) {
+ %c0 = arith.constant 0 : index
+ %pad = arith.constant 0.0 : f32
+
+// CHECK: %[[C0:.*]] = arith.constant 0 : index
+// CHECK: %[[PAD:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK: %[[READ:.*]] = vector.transfer_read %[[MEM]][%[[C0]], %[[C0]]], %[[PAD]] : memref<?x?xf32>, vector<1xf32>
+// CHECK: %[[SCF:.*]] = scf.for %[[I:.*]] = %[[LB]] to %[[UB]] step %[[STEP]] iter_args(%[[INIT:.*]] = %[[READ]]) -> (vector<1xf32>) {
+// CHECK: %[[VAL_6:.*]] = "val_use"(%[[INIT]]) : (vector<1xf32>) -> vector<1xf32>
+// CHECK: scf.yield %[[VAL_6]] : vector<1xf32>
+// CHECK: }
+// CHECK: vector.transfer_write %[[SCF]], %[[MEM]][%[[C0]], %[[C0]]] : vector<1xf32>, memref<?x?xf32>
+ scf.for %i = %lb to %ub step %step {
+ %r0 = vector.transfer_read %mem[%c0, %c0], %pad: memref<?x?xf32>, vector<1xf32>
+ %u0 = "val_use"(%r0) : (vector<1xf32>) -> vector<1xf32>
+ vector.transfer_write %u0, %mem[%c0, %c0] : vector<1xf32>, memref<?x?xf32>
+ }
+ return
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["func.func"]} in %arg1
+ : (!transform.any_op) -> !transform.any_op
+ transform.structured.hoist_redundant_vector_transfers %0
+ : (!transform.any_op) -> !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
+// Similar to the example above, but hoisting is no longer safe. That's due to
+// an extra xfer_write inside the loop.
+
+// CHECK-LABEL: func.func @negative_hoist_basic_vector_xfer_pair_extra_write(
+// CHECK-SAME: %[[MEM:[a-zA-Z0-9]+]]: memref<?x?xf32>,
+// CHECK-SAME: %[[LB:[a-zA-Z0-9]+]]: index,
+// CHECK-SAME: %[[UB:[a-zA-Z0-9]+]]: index,
+// CHECK-SAME: %[[STEP:[a-zA-Z0-9]+]]: index,
+// CHECK-SAME: %[[IN:[a-zA-Z0-9]+]]: vector<1xf32>) {
+func.func @negative_hoist_basic_vector_xfer_pair_extra_write(
+ %mem: memref<?x?xf32>, %lb : index, %ub : index, %step: index, %in: vector<1xf32>) {
+ %c0 = arith.constant 0 : index
+ %pad = arith.constant 0.0 : f32
+
+// CHECK: %[[C0:.*]] = arith.constant 0 : index
+// CHECK: %[[PAD:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK: scf.for %[[I:.*]] = %[[LB]] to %[[UB]] step %[[STEP]] {
+// CHECK: vector.transfer_write %[[IN]], %[[MEM]][%[[C0]], %[[C0]]] : vector<1xf32>, memref<?x?xf32>
+// CHECK: %[[READ:.*]] = vector.transfer_read %[[MEM]][%[[C0]], %[[C0]]], %[[PAD]] : memref<?x?xf32>, vector<1xf32>
+// CHECK: %[[USE:.*]] = "val_use"(%[[READ]]) : (vector<1xf32>) -> vector<1xf32>
+// CHECK: vector.transfer_write %[[USE]], %[[MEM]][%[[C0]], %[[C0]]] : vector<1xf32>, memref<?x?xf32>
+// CHECK: }
+
+ scf.for %i = %lb to %ub step %step {
+ vector.transfer_write %in, %mem[%c0, %c0] : vector<1xf32>, memref<?x?xf32>
+
+ %r0 = vector.transfer_read %mem[%c0, %c0], %pad: memref<?x?xf32>, vector<1xf32>
+ %u0 = "val_use"(%r0) : (vector<1xf32>) -> vector<1xf32>
+ vector.transfer_write %u0, %mem[%c0, %c0] : vector<1xf32>, memref<?x?xf32>
+ }
+ return
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["func.func"]} in %arg1
+ : (!transform.any_op) -> !transform.any_op
+ transform.structured.hoist_redundant_vector_transfers %0
+ : (!transform.any_op) -> !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
+// Similar to the example above, but hoisting is no longer safe. That's due to
+// an extra xfer_write into _an alias_ of %mem that is used by the
+// original xfer pair.
+
+// CHECK-LABEL: func.func @negative_hoist_basic_vector_xfer_pair_extra_write_into_alias(
+// CHECK-SAME: %[[MEM:[a-zA-Z0-9]+]]: memref<?x?xf32>,
+// CHECK-SAME: %[[LB:[a-zA-Z0-9]+]]: index,
+// CHECK-SAME: %[[UB:[a-zA-Z0-9]+]]: index,
+// CHECK-SAME: %[[STEP:[a-zA-Z0-9]+]]: index,
+// CHECK-SAME: %[[IN:[a-zA-Z0-9]+]]: vector<1xf32>) {
+func.func @negative_hoist_basic_vector_xfer_pair_extra_write_into_alias(
+ %mem: memref<?x?xf32>, %lb : index, %ub : index, %step: index, %in: vector<1xf32>) {
+ %c0 = arith.constant 0 : index
+ %pad = arith.constant 0.0 : f32
+
+// CHECK: %[[C0:.*]] = arith.constant 0 : index
+// CHECK: %[[PAD:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK: %[[SV:.*]] = memref.subview %[[MEM]][0, 0] [1, 1] [1, 1] : memref<?x?xf32> to memref<1x1xf32, strided<[?, 1]>>
+// CHECK: scf.for %[[I:.*]] = %[[LB]] to %[[UB]] step %[[STEP]] {
+// CHECK: vector.transfer_write %[[IN]], %[[SV]][%[[C0]], %[[C0]]] {{.*}} : vector<1xf32>, memref<1x1xf32, strided<[?, 1]>>
+// CHECK: %[[READ:.*]] = vector.transfer_read %[[MEM]][%[[C0]], %[[C0]]], %[[PAD]] : memref<?x?xf32>, vector<1xf32>
+// CHECK: %[[USE:.*]] = "val_use"(%[[READ]]) : (vector<1xf32>) -> vector<1xf32>
+// CHECK: vector.transfer_write %[[USE]], %[[MEM]][%[[C0]], %[[C0]]] : vector<1xf32>, memref<?x?xf32>
+// CHECK: }
+
+ %sv = memref.subview %mem[0, 0][1, 1][1, 1] : memref<?x?xf32> to memref<1x1xf32, strided<[?, 1]>>
+ scf.for %i = %lb to %ub step %step {
+ vector.transfer_write %in, %sv[%c0, %c0] : vector<1xf32>, memref<1x1xf32, strided<[?, 1]>>
+
+ %r0 = vector.transfer_read %mem[%c0, %c0], %pad: memref<?x?xf32>, vector<1xf32>
+ %u0 = "val_use"(%r0) : (vector<1xf32>) -> vector<1xf32>
+ vector.transfer_write %u0, %mem[%c0, %c0] : vector<1xf32>, memref<?x?xf32>
+ }
+ return
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["func.func"]} in %arg1
+ : (!transform.any_op) -> !transform.any_op
+ transform.structured.hoist_redundant_vector_transfers %0
+ : (!transform.any_op) -> !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
+// Similar to the example above, but the memory access is done via
+// memref.assume_alignment. Hoisting is safe, as the only users of the
+// "alignment" Op are the xfer Ops within the loop that we want to hoist.
+
+// CHECK-LABEL: func.func @hoist_basic_vector_xfer_pair_with_assume_align(
+// CHECK-SAME: %[[MEM:[a-zA-Z0-9]+]]: memref<?x?xf32>,
+// CHECK-SAME: %[[LB:[a-zA-Z0-9]+]]: index,
+// CHECK-SAME: %[[UB:[a-zA-Z0-9]+]]: index,
+// CHECK-SAME: %[[STEP:[a-zA-Z0-9]+]]: index,
+// CHECK-SAME: %[[IN:[a-zA-Z0-9]+]]: vector<1xf32>) {
+func.func @hoist_basic_vector_xfer_pair_with_assume_align(
+ %mem: memref<?x?xf32>, %lb : index, %ub : index, %step: index, %in: vector<1xf32>) {
+ %c0 = arith.constant 0 : index
+ %pad = arith.constant 0.0 : f32
+
+// CHECK: %[[C0:.*]] = arith.constant 0 : index
+// CHECK: %[[PAD:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK: %[[AA:.*]] = memref.assume_alignment %[[MEM]], 4 : memref<?x?xf32>
+// CHECK: %[[READ:.*]] = vector.transfer_read %[[AA]][%[[C0]], %[[C0]]], %[[PAD]] : memref<?x?xf32>, vector<1xf32>
+// CHECK: %[[SCF:.*]] = scf.for %[[I:.*]] = %[[LB]] to %[[UB]] step %[[STEP]] iter_args(%[[INIT:.*]] = %[[READ]]) -> (vector<1xf32>) {
+// CHECK: %[[USE:.*]] = "val_use"(%[[INIT]]) : (vector<1xf32>) -> vector<1xf32>
+// CHECK: }
+// CHECK: vector.transfer_write %[[SCF]], %[[AA]][%[[C0]], %[[C0]]] : vector<1xf32>, memref<?x?xf32>
+
+ %aa = memref.assume_alignment %mem, 4 : memref<?x?xf32>
+ scf.for %i = %lb to %ub step %step {
+ %r0 = vector.transfer_read %aa[%c0, %c0], %pad: memref<?x?xf32>, vector<1xf32>
+ %u0 = "val_use"(%r0) : (vector<1xf32>) -> vector<1xf32>
+ vector.transfer_write %u0, %aa[%c0, %c0] : vector<1xf32>, memref<?x?xf32>
+ }
+ return
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["func.func"]} in %arg1
+ : (!transform.any_op) -> !transform.any_op
+ transform.structured.hoist_redundant_vector_transfers %0
+ : (!transform.any_op) -> !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
+// Similar to the example above, but hoisting is not safe due to an extra
+// memory access inside the loop via the original memref.
+
+// CHECK-LABEL: func.func @negative_hoist_basic_vector_xfer_pair_with_assume_align(
+// CHECK-SAME: %[[MEM:[a-zA-Z0-9]+]]: memref<?x?xf32>,
+// CHECK-SAME: %[[LB:[a-zA-Z0-9]+]]: index,
+// CHECK-SAME: %[[UB:[a-zA-Z0-9]+]]: index,
+// CHECK-SAME: %[[STEP:[a-zA-Z0-9]+]]: index,
+// CHECK-SAME: %[[IN:[a-zA-Z0-9]+]]: vector<1xf32>) {
+func.func @negative_hoist_basic_vector_xfer_pair_with_assume_align(
+ %mem: memref<?x?xf32>, %lb : index, %ub : index, %step: index, %in: vector<1xf32>) {
+ %c0 = arith.constant 0 : index
+ %pad = arith.constant 0.0 : f32
+
+// CHECK: %[[C0:.*]] = arith.constant 0 : index
+// CHECK: %[[PAD:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK: %[[AA:.*]] = memref.assume_alignment %[[MEM]], 4 : memref<?x?xf32>
+// CHECK: scf.for %[[I:.*]] = %[[LB]] to %[[UB]] step %[[STEP]] {
+// CHECK: %[[READ:.*]] = vector.transfer_read %[[AA]][%[[C0]], %[[C0]]], %[[PAD]] : memref<?x?xf32>, vector<1xf32>
+// CHECK: "mem_use"(%[[MEM]])
+// CHECK: vector.transfer_write %[[READ]], %[[AA]][%[[C0]], %[[C0]]] : vector<1xf32>, memref<?x?xf32>
+// CHECK: }
+
+ %aa = memref.assume_alignment %mem, 4 : memref<?x?xf32>
+ scf.for %i = %lb to %ub step %step {
+ %r0 = vector.transfer_read %aa[%c0, %c0], %pad: memref<?x?xf32>, vector<1xf32>
+ "mem_use"(%mem) : (memref<?x?xf32>) -> ()
+ vector.transfer_write %r0, %aa[%c0, %c0] : vector<1xf32>, memref<?x?xf32>
+ }
+ return
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["func.func"]} in %arg1
+ : (!transform.any_op) -> !transform.any_op
+ transform.structured.hoist_redundant_vector_transfers %0
+ : (!transform.any_op) -> !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
///----------------------------------------------------------------------------------------
/// Tests for vector.transfer_read + vector.transfer_write pairs
///