[Mlir-commits] [mlir] [MLIR][Affine] Fix memref replacement in affine-data-copy-generate (PR #139016)
Uday Bondhugula
llvmlistbot at llvm.org
Wed May 7 21:06:58 PDT 2025
https://github.com/bondhugula created https://github.com/llvm/llvm-project/pull/139016
Fixes: https://github.com/llvm/llvm-project/issues/130257
Fix affine-data-copy-generate in certain cases that involved users in multiple blocks. Perform the memref replacement correctly during copy generation.
Improve/clean up the memref affine use replacement API. Instead of supporting dominance and post-dominance filters (which aren't adequate in most cases) and computing dominance info expensively on each call to replaceAllMemRefUsesWith (RAMUW), provide a user filter callback, i.e., require callers to compute dominance themselves if needed.
>From 5e0d8c55403a652f4ca12aac983b1ecec3c1dc58 Mon Sep 17 00:00:00 2001
From: Uday Bondhugula <uday at polymagelabs.com>
Date: Thu, 8 May 2025 04:10:16 +0530
Subject: [PATCH] [MLIR][Affine] Fix memref replacement in
affine-data-copy-generate
Fixes: https://github.com/llvm/llvm-project/issues/130257
Fix affine-data-copy-generate in certain cases that involved users in
multiple blocks. Perform the memref replacement correctly during copy
generation.
Improve/clean up the memref affine use replacement API. Instead of
supporting dominance and post-dominance filters (which aren't adequate
in most cases) and computing dominance info expensively on each call to
replaceAllMemRefUsesWith (RAMUW), provide a user filter callback, i.e.,
require callers to compute dominance themselves if needed.
---
mlir/include/mlir/Dialect/Affine/Utils.h | 14 ++---
.../Dialect/Affine/Transforms/LoopFusion.cpp | 7 ++-
.../Transforms/PipelineDataTransfer.cpp | 17 +++---
mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp | 22 +++++---
mlir/lib/Dialect/Affine/Utils/Utils.cpp | 55 ++++++++-----------
.../MemRef/Transforms/NormalizeMemRefs.cpp | 9 +--
.../test/Dialect/Affine/affine-data-copy.mlir | 48 ++++++++++++++++
7 files changed, 109 insertions(+), 63 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Affine/Utils.h b/mlir/include/mlir/Dialect/Affine/Utils.h
index ae5a68a6be157..ac11f5a7c24c7 100644
--- a/mlir/include/mlir/Dialect/Affine/Utils.h
+++ b/mlir/include/mlir/Dialect/Affine/Utils.h
@@ -198,10 +198,9 @@ AffineExpr substWithMin(AffineExpr e, AffineExpr dim, AffineExpr min,
/// of its input list. `indexRemap`'s dimensional inputs are expected to
/// correspond to memref's indices, and its symbolic inputs if any should be
/// provided in `symbolOperands`.
-///
-/// `domOpFilter`, if non-null, restricts the replacement to only those
-/// operations that are dominated by the former; similarly, `postDomOpFilter`
-/// restricts replacement to only those operations that are postdominated by it.
+///
+/// If `userFilterFn` is specified, restrict replacement to only those users
+/// that pass the specified filter (i.e., the filter returns true).
///
/// 'allowNonDereferencingOps', if set, allows replacement of non-dereferencing
/// uses of a memref without any requirement for access index rewrites as long
@@ -224,13 +223,14 @@ AffineExpr substWithMin(AffineExpr e, AffineExpr dim, AffineExpr min,
// d1, d2) -> (d0 - d1, d2), and %ii will be the extra operand. Without any
// extra operands, note that 'indexRemap' would just be applied to existing
// indices (%i, %j).
+//
// TODO: allow extraIndices to be added at any position.
LogicalResult replaceAllMemRefUsesWith(
Value oldMemRef, Value newMemRef, ArrayRef<Value> extraIndices = {},
AffineMap indexRemap = AffineMap(), ArrayRef<Value> extraOperands = {},
- ArrayRef<Value> symbolOperands = {}, Operation *domOpFilter = nullptr,
- Operation *postDomOpFilter = nullptr, bool allowNonDereferencingOps = false,
- bool replaceInDeallocOp = false);
+ ArrayRef<Value> symbolOperands = {},
+ llvm::function_ref<bool(Operation *)> userFilterFn = nullptr,
+ bool allowNonDereferencingOps = false, bool replaceInDeallocOp = false);
/// Performs the same replacement as the other version above but only for the
/// dereferencing uses of `oldMemRef` in `op`, except in cases where
diff --git a/mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp b/mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp
index 4b4eb9ce37b4c..da05dec6e4af3 100644
--- a/mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp
+++ b/mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp
@@ -445,10 +445,15 @@ static Value createPrivateMemRef(AffineForOp forOp,
// Replace all users of 'oldMemRef' with 'newMemRef'.
Operation *domFilter =
getDominanceFilterForPrivateMemRefRepl(sliceInsertionBlock, storeOps);
+ auto userFilterFn = [&](Operation *user) {
+ auto domInfo = std::make_unique<DominanceInfo>(
+ domFilter->getParentOfType<FunctionOpInterface>());
+ return domInfo->dominates(domFilter, user);
+ };
LogicalResult res = replaceAllMemRefUsesWith(
oldMemRef, newMemRef, /*extraIndices=*/{}, indexRemap,
/*extraOperands=*/outerIVs,
- /*symbolOperands=*/{}, domFilter);
+ /*symbolOperands=*/{}, userFilterFn);
assert(succeeded(res) &&
"replaceAllMemrefUsesWith should always succeed here");
(void)res;
diff --git a/mlir/lib/Dialect/Affine/Transforms/PipelineDataTransfer.cpp b/mlir/lib/Dialect/Affine/Transforms/PipelineDataTransfer.cpp
index 4be99aa197380..92cb7075005a3 100644
--- a/mlir/lib/Dialect/Affine/Transforms/PipelineDataTransfer.cpp
+++ b/mlir/lib/Dialect/Affine/Transforms/PipelineDataTransfer.cpp
@@ -115,13 +115,16 @@ static bool doubleBuffer(Value oldMemRef, AffineForOp forOp) {
// replaceAllMemRefUsesWith will succeed unless the forOp body has
// non-dereferencing uses of the memref (dealloc's are fine though).
- if (failed(replaceAllMemRefUsesWith(
- oldMemRef, newMemRef,
- /*extraIndices=*/{ivModTwoOp},
- /*indexRemap=*/AffineMap(),
- /*extraOperands=*/{},
- /*symbolOperands=*/{},
- /*domOpFilter=*/&*forOp.getBody()->begin()))) {
+ auto userFilterFn = [&](Operation *user) {
+ auto domInfo = std::make_unique<DominanceInfo>(
+ forOp->getParentOfType<FunctionOpInterface>());
+ return domInfo->dominates(&*forOp.getBody()->begin(), user);
+ };
+ if (failed(replaceAllMemRefUsesWith(oldMemRef, newMemRef,
+ /*extraIndices=*/{ivModTwoOp},
+ /*indexRemap=*/AffineMap(),
+ /*extraOperands=*/{},
+ /*symbolOperands=*/{}, userFilterFn))) {
LLVM_DEBUG(
forOp.emitError("memref replacement for double buffering failed"));
ivModTwoOp.erase();
diff --git a/mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp b/mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp
index 0d4ba3940c48e..8c2761596da13 100644
--- a/mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp
+++ b/mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp
@@ -1969,6 +1969,12 @@ static LogicalResult generateCopy(
if (begin == end)
return success();
+ // Record the last op in the block for which we are performing copy
+ // generation. We later do the memref replacement only in [begin, lastCopyOp]
+ // so that the original memrefs used in the data movement code themselves
+ // don't get replaced.
+ Operation *lastCopyOp = end->getPrevNode();
+
// Is the copy out point at the end of the block where we are doing
// explicit copying.
bool isCopyOutAtEndOfBlock = (end == copyOutPlacementStart);
@@ -2145,12 +2151,6 @@ static LogicalResult generateCopy(
}
}
- // Record the last operation where we want the memref replacement to end. We
- // later do the memref replacement only in [begin, postDomFilter] so
- // that the original memref's used in the data movement code themselves don't
- // get replaced.
- auto postDomFilter = std::prev(end);
-
// Create fully composed affine maps for each memref.
auto memAffineMap = b.getMultiDimIdentityMap(memIndices.size());
fullyComposeAffineMapAndOperands(&memAffineMap, &memIndices);
@@ -2246,13 +2246,17 @@ static LogicalResult generateCopy(
if (!isBeginAtStartOfBlock)
prevOfBegin = std::prev(begin);
+ auto userFilterFn = [&](Operation *user) {
+ auto *ancestorUser = block->findAncestorOpInBlock(*user);
+ return ancestorUser && !ancestorUser->isBeforeInBlock(&*begin) &&
+ !lastCopyOp->isBeforeInBlock(ancestorUser);
+ };
+
// *Only* those uses within the range [begin, end) of 'block' are replaced.
(void)replaceAllMemRefUsesWith(memref, fastMemRef,
/*extraIndices=*/{}, indexRemap,
/*extraOperands=*/regionSymbols,
- /*symbolOperands=*/{},
- /*domOpFilter=*/&*begin,
- /*postDomOpFilter=*/&*postDomFilter);
+ /*symbolOperands=*/{}, userFilterFn);
*nBegin = isBeginAtStartOfBlock ? block->begin() : std::next(prevOfBegin);
diff --git a/mlir/lib/Dialect/Affine/Utils/Utils.cpp b/mlir/lib/Dialect/Affine/Utils/Utils.cpp
index cde8223107859..66b3f2a4f93a5 100644
--- a/mlir/lib/Dialect/Affine/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Affine/Utils/Utils.cpp
@@ -1305,9 +1305,9 @@ LogicalResult mlir::affine::replaceAllMemRefUsesWith(
LogicalResult mlir::affine::replaceAllMemRefUsesWith(
Value oldMemRef, Value newMemRef, ArrayRef<Value> extraIndices,
AffineMap indexRemap, ArrayRef<Value> extraOperands,
- ArrayRef<Value> symbolOperands, Operation *domOpFilter,
- Operation *postDomOpFilter, bool allowNonDereferencingOps,
- bool replaceInDeallocOp) {
+ ArrayRef<Value> symbolOperands,
+ llvm::function_ref<bool(Operation *)> userFilterFn,
+ bool allowNonDereferencingOps, bool replaceInDeallocOp) {
unsigned newMemRefRank = cast<MemRefType>(newMemRef.getType()).getRank();
(void)newMemRefRank; // unused in opt mode
unsigned oldMemRefRank = cast<MemRefType>(oldMemRef.getType()).getRank();
@@ -1328,61 +1328,52 @@ LogicalResult mlir::affine::replaceAllMemRefUsesWith(
std::unique_ptr<DominanceInfo> domInfo;
std::unique_ptr<PostDominanceInfo> postDomInfo;
- if (domOpFilter)
- domInfo = std::make_unique<DominanceInfo>(
- domOpFilter->getParentOfType<FunctionOpInterface>());
-
- if (postDomOpFilter)
- postDomInfo = std::make_unique<PostDominanceInfo>(
- postDomOpFilter->getParentOfType<FunctionOpInterface>());
// Walk all uses of old memref; collect ops to perform replacement. We use a
// DenseSet since an operation could potentially have multiple uses of a
// memref (although rare), and the replacement later is going to erase ops.
DenseSet<Operation *> opsToReplace;
- for (auto *op : oldMemRef.getUsers()) {
- // Skip this use if it's not dominated by domOpFilter.
- if (domOpFilter && !domInfo->dominates(domOpFilter, op))
- continue;
-
- // Skip this use if it's not post-dominated by postDomOpFilter.
- if (postDomOpFilter && !postDomInfo->postDominates(postDomOpFilter, op))
+ for (auto *user : oldMemRef.getUsers()) {
+ // Check if this user doesn't pass the filter.
+ if (userFilterFn && !userFilterFn(user))
continue;
// Skip dealloc's - no replacement is necessary, and a memref replacement
// at other uses doesn't hurt these dealloc's.
- if (hasSingleEffect<MemoryEffects::Free>(op, oldMemRef) &&
+ if (hasSingleEffect<MemoryEffects::Free>(user, oldMemRef) &&
!replaceInDeallocOp)
continue;
// Check if the memref was used in a non-dereferencing context. It is fine
// for the memref to be used in a non-dereferencing way outside of the
// region where this replacement is happening.
- if (!isa<AffineMapAccessInterface>(*op)) {
+ if (!isa<AffineMapAccessInterface>(*user)) {
if (!allowNonDereferencingOps) {
- LLVM_DEBUG(llvm::dbgs()
- << "Memref replacement failed: non-deferencing memref op: \n"
- << *op << '\n');
+ LLVM_DEBUG(
+ llvm::dbgs()
+ << "Memref replacement failed: non-deferencing memref user: \n"
+ << *user << '\n');
return failure();
}
// Non-dereferencing ops with the MemRefsNormalizable trait are
// supported for replacement.
- if (!op->hasTrait<OpTrait::MemRefsNormalizable>()) {
+ if (!user->hasTrait<OpTrait::MemRefsNormalizable>()) {
LLVM_DEBUG(llvm::dbgs() << "Memref replacement failed: use without a "
"memrefs normalizable trait: \n"
- << *op << '\n');
+ << *user << '\n');
return failure();
}
}
- // We'll first collect and then replace --- since replacement erases the op
- // that has the use, and that op could be postDomFilter or domFilter itself!
- opsToReplace.insert(op);
+ // We'll first collect and then replace --- since replacement erases the
+ // user that has the use, and erasing users while walking the memref's
+ // user list is unsafe.
+ opsToReplace.insert(user);
}
- for (auto *op : opsToReplace) {
+ for (auto *user : opsToReplace) {
if (failed(replaceAllMemRefUsesWith(
- oldMemRef, newMemRef, op, extraIndices, indexRemap, extraOperands,
+ oldMemRef, newMemRef, user, extraIndices, indexRemap, extraOperands,
symbolOperands, allowNonDereferencingOps)))
llvm_unreachable("memref replacement guaranteed to succeed here");
}
@@ -1763,8 +1754,7 @@ LogicalResult mlir::affine::normalizeMemRef(AllocLikeOp allocOp) {
/*indexRemap=*/layoutMap,
/*extraOperands=*/{},
/*symbolOperands=*/symbolOperands,
- /*domOpFilter=*/nullptr,
- /*postDomOpFilter=*/nullptr,
+ /*userFilterFn=*/nullptr,
/*allowNonDereferencingOps=*/true))) {
// If it failed (due to escapes for example), bail out.
newAlloc.erase();
@@ -1854,8 +1844,7 @@ mlir::affine::normalizeMemRef(memref::ReinterpretCastOp reinterpretCastOp) {
/*indexRemap=*/oldLayoutMap,
/*extraOperands=*/{},
/*symbolOperands=*/oldStrides,
- /*domOpFilter=*/nullptr,
- /*postDomOpFilter=*/nullptr,
+ /*userFilterFn=*/nullptr,
/*allowNonDereferencingOps=*/true))) {
// If it failed (due to escapes for example), bail out.
newReinterpretCast.erase();
diff --git a/mlir/lib/Dialect/MemRef/Transforms/NormalizeMemRefs.cpp b/mlir/lib/Dialect/MemRef/Transforms/NormalizeMemRefs.cpp
index b408962690810..d6fcb8d9f0501 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/NormalizeMemRefs.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/NormalizeMemRefs.cpp
@@ -297,8 +297,7 @@ void NormalizeMemRefs::updateFunctionSignature(func::FuncOp funcOp,
/*indexRemap=*/layoutMap,
/*extraOperands=*/{},
/*symbolOperands=*/{},
- /*domOpFilter=*/nullptr,
- /*postDomOpFilter=*/nullptr,
+ /*userFilterFn=*/nullptr,
/*allowNonDereferencingOps=*/true,
/*replaceInDeallocOp=*/true))) {
// If it failed (due to escapes for example), bail out.
@@ -407,8 +406,7 @@ void NormalizeMemRefs::normalizeFuncOpMemRefs(func::FuncOp funcOp,
/*indexRemap=*/layoutMap,
/*extraOperands=*/{},
/*symbolOperands=*/{},
- /*domOpFilter=*/nullptr,
- /*postDomOpFilter=*/nullptr,
+ /*userFilterFn=*/nullptr,
/*allowNonDereferencingOps=*/true,
/*replaceInDeallocOp=*/true))) {
// If it failed (due to escapes for example), bail out. Removing the
@@ -457,8 +455,7 @@ void NormalizeMemRefs::normalizeFuncOpMemRefs(func::FuncOp funcOp,
/*indexRemap=*/layoutMap,
/*extraOperands=*/{},
/*symbolOperands=*/{},
- /*domOpFilter=*/nullptr,
- /*postDomOpFilter=*/nullptr,
+ /*userFilterFn=*/nullptr,
/*allowNonDereferencingOps=*/true,
/*replaceInDeallocOp=*/true))) {
newOp->erase();
diff --git a/mlir/test/Dialect/Affine/affine-data-copy.mlir b/mlir/test/Dialect/Affine/affine-data-copy.mlir
index a1f0d952e7c63..a745271eb9ca8 100644
--- a/mlir/test/Dialect/Affine/affine-data-copy.mlir
+++ b/mlir/test/Dialect/Affine/affine-data-copy.mlir
@@ -447,3 +447,51 @@ func.func @memref_def_inside(%arg0: index) {
// LIMITED-MEM-NEXT: memref.dealloc %{{.*}} : memref<1xf32>
return
}
+
+// Test with uses across multiple blocks.
+
+memref.global "private" constant @__constant_1x2x1xi32_1 : memref<1x2x1xi32> = dense<0> {alignment = 64 : i64}
+
+// CHECK-LABEL: func @multiple_blocks
+func.func @multiple_blocks(%arg0: index) -> memref<1x2x1xi32> {
+ %c1_i32 = arith.constant 1 : i32
+ %c3_i32 = arith.constant 3 : i32
+ %0 = memref.get_global @__constant_1x2x1xi32_1 : memref<1x2x1xi32>
+ %alloc = memref.alloc() {alignment = 64 : i64} : memref<1x2x1xi32>
+ memref.copy %0, %alloc : memref<1x2x1xi32> to memref<1x2x1xi32>
+ cf.br ^bb1(%alloc : memref<1x2x1xi32>)
+^bb1(%1: memref<1x2x1xi32>): // 2 preds: ^bb0, ^bb2
+// CHECK: ^bb1(%[[MEM:.*]]: memref<1x2x1xi32>):
+ %alloc_0 = memref.alloc() {alignment = 64 : i64} : memref<1x2x1xi1>
+ // CHECK: %[[BUF:.*]] = memref.alloc() : memref<1x2x1xi32>
+ affine.for %arg1 = 0 to 1 {
+ affine.for %arg2 = 0 to 2 {
+ affine.for %arg3 = 0 to 1 {
+ // CHECK: affine.load %[[BUF]]
+ %3 = affine.load %1[%arg1, %arg2, %arg3] : memref<1x2x1xi32>
+ %4 = arith.cmpi slt, %3, %c3_i32 : i32
+ affine.store %4, %alloc_0[%arg1, %arg2, %arg3] : memref<1x2x1xi1>
+ }
+ }
+ }
+ // CHECK: memref.dealloc %[[BUF]]
+ %2 = memref.load %alloc_0[%arg0, %arg0, %arg0] : memref<1x2x1xi1>
+ cf.cond_br %2, ^bb2, ^bb3
+^bb2: // pred: ^bb1
+// CHECK: ^bb2
+ %alloc_1 = memref.alloc() {alignment = 64 : i64} : memref<1x2x1xi32>
+ affine.for %arg1 = 0 to 1 {
+ affine.for %arg2 = 0 to 2 {
+ affine.for %arg3 = 0 to 1 {
+ // Ensure that this reference isn't replaced.
+ %3 = affine.load %1[%arg1, %arg2, %arg3] : memref<1x2x1xi32>
+ // CHECK: affine.load %[[MEM]]
+ %4 = arith.addi %3, %c1_i32 : i32
+ affine.store %4, %alloc_1[%arg1, %arg2, %arg3] : memref<1x2x1xi32>
+ }
+ }
+ }
+ cf.br ^bb1(%alloc_1 : memref<1x2x1xi32>)
+^bb3: // pred: ^bb1
+ return %1 : memref<1x2x1xi32>
+}
More information about the Mlir-commits
mailing list