[Mlir-commits] [mlir] [MLIR][Affine] Fix private memref creation bug in affine fusion (PR #126028)
Uday Bondhugula
llvmlistbot at llvm.org
Thu Feb 6 01:21:13 PST 2025
https://github.com/bondhugula updated https://github.com/llvm/llvm-project/pull/126028
From f275eda6a2303230cc5bfe8c2e09474d5b624496 Mon Sep 17 00:00:00 2001
From: Uday Bondhugula <uday at polymagelabs.com>
Date: Thu, 6 Feb 2025 14:06:46 +0530
Subject: [PATCH] [MLIR][Affine] Fix private memref creation bug in affine
fusion
Fix a private memref creation bug in affine fusion that was exposed
when the same memref is both loaded from and stored to in the producer
nest. Make the private memref replacement sound.
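
For reference, a minimal sketch of the problematic pattern, mirroring
the new test case (%A, %B, %cst, and %i are placeholder names): the
producer nest both loads from and stores to %A, so after fusion only
the producer's store and the consumer's load may be redirected to the
privatized single-element buffer, while the load feeding the producer
computation must keep reading the original %A.

  affine.for %i = 0 to 32 {
    %v = affine.load %A[%i] : memref<32xf32>
    %m = arith.mulf %v, %cst : f32
    affine.store %m, %A[%i] : memref<32xf32>
  }
  affine.for %i = 0 to 16 {
    %v = affine.load %A[%i] : memref<32xf32>
    %s = arith.addf %v, %cst : f32
    affine.store %s, %B[%i] : memref<16xf32>
  }
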
Change the affine fusion debug string to `affine-fusion`, which is more
compact.
Fixes: https://github.com/llvm/llvm-project/issues/48703
---
.../mlir/Dialect/Affine/Analysis/Utils.h | 8 +++
mlir/lib/Dialect/Affine/Analysis/Utils.cpp | 38 ++++++++++
.../Dialect/Affine/Transforms/LoopFusion.cpp | 71 +++++++++++++++----
mlir/test/Dialect/Affine/loop-fusion-4.mlir | 29 ++++++++
4 files changed, 134 insertions(+), 12 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Affine/Analysis/Utils.h b/mlir/include/mlir/Dialect/Affine/Analysis/Utils.h
index b1fbf4477428ca2..cff4983af17fc90 100644
--- a/mlir/include/mlir/Dialect/Affine/Analysis/Utils.h
+++ b/mlir/include/mlir/Dialect/Affine/Analysis/Utils.h
@@ -610,6 +610,14 @@ FailureOr<AffineValueMap>
simplifyConstrainedMinMaxOp(Operation *op,
FlatAffineValueConstraints constraints);
+/// Find the innermost common `Block` of `a` and `b` in the affine scope
+/// that `a` and `b` are part of. Return nullptr if they belong to different
+/// affine scopes. Also, return nullptr if they do not have a common `Block`
+/// ancestor (e.g., when they are part of the `then` and `else` regions
+/// of an op that itself starts an affine scope).
+mlir::Block *findInnermostCommonBlockInScope(mlir::Operation *a,
+ mlir::Operation *b);
+
} // namespace affine
} // namespace mlir
diff --git a/mlir/lib/Dialect/Affine/Analysis/Utils.cpp b/mlir/lib/Dialect/Affine/Analysis/Utils.cpp
index 9c0b5dbf52d299b..d6c62cdd613643e 100644
--- a/mlir/lib/Dialect/Affine/Analysis/Utils.cpp
+++ b/mlir/lib/Dialect/Affine/Analysis/Utils.cpp
@@ -12,6 +12,7 @@
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Affine/Analysis/Utils.h"
+
#include "mlir/Analysis/Presburger/PresburgerRelation.h"
#include "mlir/Dialect/Affine/Analysis/AffineAnalysis.h"
#include "mlir/Dialect/Affine/Analysis/LoopAnalysis.h"
@@ -2297,3 +2298,40 @@ FailureOr<AffineValueMap> mlir::affine::simplifyConstrainedMinMaxOp(
affine::canonicalizeMapAndOperands(&newMap, &newOperands);
return AffineValueMap(newMap, newOperands);
}
+
+Block *mlir::affine::findInnermostCommonBlockInScope(Operation *a,
+ Operation *b) {
+ Region *aScope = mlir::affine::getAffineScope(a);
+ Region *bScope = mlir::affine::getAffineScope(b);
+ if (aScope != bScope)
+ return nullptr;
+
+ // Get the block ancestry of `op` while stopping at the affine scope.
+ auto getBlockAncestry = [&](Operation *op,
+ SmallVectorImpl<Block *> &ancestry) {
+ Operation *curOp = op;
+ do {
+ ancestry.push_back(curOp->getBlock());
+ if (curOp->getParentRegion() == aScope)
+ break;
+ curOp = curOp->getParentOp();
+ } while (curOp);
+ assert(curOp && "can't reach root op without passing through affine scope");
+ std::reverse(ancestry.begin(), ancestry.end());
+ };
+
+ SmallVector<Block *, 4> aAncestors, bAncestors;
+ getBlockAncestry(a, aAncestors);
+ getBlockAncestry(b, bAncestors);
+ assert(!aAncestors.empty() && !bAncestors.empty() &&
+ "at least one Block ancestor expected");
+
+ Block *innermostCommonBlock = nullptr;
+ for (unsigned a = 0, b = 0, e = aAncestors.size(), f = bAncestors.size();
+ a < e && b < f; ++a, ++b) {
+ if (aAncestors[a] != bAncestors[b])
+ break;
+ innermostCommonBlock = aAncestors[a];
+ }
+ return innermostCommonBlock;
+}
diff --git a/mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp b/mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp
index c22ec213be95c84..0ea27df704d0694 100644
--- a/mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp
+++ b/mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp
@@ -41,7 +41,7 @@ namespace affine {
} // namespace affine
} // namespace mlir
-#define DEBUG_TYPE "affine-loop-fusion"
+#define DEBUG_TYPE "affine-fusion"
using namespace mlir;
using namespace mlir::affine;
@@ -237,29 +237,71 @@ static void sinkSequentialLoops(MemRefDependenceGraph::Node *node) {
node->op = newRootForOp;
}
+/// Get the operation that should act as a dominance filter while replacing
+/// memref uses with a private memref for which `producerStores` and
+/// `sliceInsertionBlock` are provided. This effectively determines in what
+/// part of the IR we should be performing the replacement.
+static Operation *
+getDominanceFilterForPrivateMemRefRepl(Block *sliceInsertionBlock,
+ ArrayRef<Operation *> producerStores) {
+ assert(!producerStores.empty() && "expected producer store");
+
+ // We first find the common block that contains the producer stores and
+ // the slice computation. The first ancestor among the ancestors of the
+ // producer stores in that common block is the dominance filter to use for
+ // replacement.
+ Block *commonBlock = nullptr;
+ // Find the common block of all relevant operations.
+ for (Operation *store : producerStores) {
+ if (!commonBlock)
+ commonBlock = findInnermostCommonBlockInScope(
+ store, &*sliceInsertionBlock->begin());
+ else
+ commonBlock =
+ findInnermostCommonBlockInScope(store, &*commonBlock->begin());
+ }
+ assert(commonBlock &&
+ "common block of producer stores and slice should exist");
+
+ // Find the first ancestor among the ancestors of `producerStores` in
+ // `commonBlock`.
+ Operation *firstAncestor = nullptr;
+ for (Operation *store : producerStores) {
+ Operation *ancestor = commonBlock->findAncestorOpInBlock(*store);
+ assert(ancestor && "producer store should be contained in common block");
+ firstAncestor = !firstAncestor || ancestor->isBeforeInBlock(firstAncestor)
+ ? ancestor
+ : firstAncestor;
+ }
+ return firstAncestor;
+}
+
// Creates and returns a private (single-user) memref for fused loop rooted
// at 'forOp', with (potentially reduced) memref size based on the
// MemRefRegion written to by 'srcStoreOpInst' at depth 'dstLoopDepth'.
// TODO: consider refactoring the common code from generateDma and
// this one.
-static Value createPrivateMemRef(AffineForOp forOp, Operation *srcStoreOpInst,
+static Value createPrivateMemRef(AffineForOp forOp,
+ ArrayRef<Operation *> storeOps,
unsigned dstLoopDepth,
std::optional<unsigned> fastMemorySpace,
+ Block *sliceInsertionBlock,
uint64_t localBufSizeThreshold) {
- Operation *forInst = forOp.getOperation();
+ assert(!storeOps.empty() && "no source stores supplied");
+ Operation *srcStoreOp = storeOps[0];
// Create builder to insert alloc op just before 'forOp'.
- OpBuilder b(forInst);
+ OpBuilder b(forOp);
// Builder to create constants at the top level.
- OpBuilder top(forInst->getParentRegion());
+ OpBuilder top(forOp->getParentRegion());
// Create new memref type based on slice bounds.
- auto oldMemRef = cast<AffineWriteOpInterface>(srcStoreOpInst).getMemRef();
+ auto oldMemRef = cast<AffineWriteOpInterface>(srcStoreOp).getMemRef();
auto oldMemRefType = cast<MemRefType>(oldMemRef.getType());
unsigned rank = oldMemRefType.getRank();
// Compute MemRefRegion for 'srcStoreOpInst' at depth 'dstLoopDepth'.
- MemRefRegion region(srcStoreOpInst->getLoc());
- bool validRegion = succeeded(region.compute(srcStoreOpInst, dstLoopDepth));
+ MemRefRegion region(srcStoreOp->getLoc());
+ bool validRegion = succeeded(region.compute(srcStoreOp, dstLoopDepth));
(void)validRegion;
assert(validRegion && "unexpected memref region failure");
SmallVector<int64_t, 4> newShape;
@@ -332,11 +374,12 @@ static Value createPrivateMemRef(AffineForOp forOp, Operation *srcStoreOpInst,
AffineMap::get(outerIVs.size() + rank, 0, remapExprs, forOp.getContext());
// Replace all users of 'oldMemRef' with 'newMemRef'.
+ Operation *domFilter =
+ getDominanceFilterForPrivateMemRefRepl(sliceInsertionBlock, storeOps);
LogicalResult res =
replaceAllMemRefUsesWith(oldMemRef, newMemRef, {}, indexRemap,
/*extraOperands=*/outerIVs,
- /*symbolOperands=*/{},
- /*domOpFilter=*/&*forOp.getBody()->begin());
+ /*symbolOperands=*/{}, domFilter);
assert(succeeded(res) &&
"replaceAllMemrefUsesWith should always succeed here");
(void)res;
@@ -944,6 +987,10 @@ struct GreedyFusion {
// Create private memrefs.
if (!privateMemrefs.empty()) {
+ // Note the block into which fusion was performed. This can be used to
+ // place `alloc`s that create private memrefs.
+ Block *sliceInsertionBlock = bestSlice.insertPoint->getBlock();
+
// Gather stores for all the private-to-be memrefs.
DenseMap<Value, SmallVector<Operation *, 4>> privateMemRefToStores;
dstAffineForOp.walk([&](AffineWriteOpInterface storeOp) {
@@ -962,8 +1009,8 @@ struct GreedyFusion {
SmallVector<Operation *, 4> &storesForMemref =
memrefToStoresPair.second;
Value newMemRef = createPrivateMemRef(
- dstAffineForOp, storesForMemref[0], bestDstLoopDepth,
- fastMemorySpace, localBufSizeThreshold);
+ dstAffineForOp, storesForMemref, bestDstLoopDepth,
+ fastMemorySpace, sliceInsertionBlock, localBufSizeThreshold);
// Create new node in dependence graph for 'newMemRef' alloc op.
unsigned newMemRefNodeId = mdg->addNode(newMemRef.getDefiningOp());
// Add edge from 'newMemRef' node to dstNode.
diff --git a/mlir/test/Dialect/Affine/loop-fusion-4.mlir b/mlir/test/Dialect/Affine/loop-fusion-4.mlir
index ea144f73bb21c6d..1241a46fb389419 100644
--- a/mlir/test/Dialect/Affine/loop-fusion-4.mlir
+++ b/mlir/test/Dialect/Affine/loop-fusion-4.mlir
@@ -285,3 +285,32 @@ module {
spirv.ReturnValue %3 : !spirv.array<8192 x f32>
}
}
+
+// -----
+
+// PRODUCER-CONSUMER-LABEL: func @same_memref_load_store
+func.func @same_memref_load_store(%producer: memref<32xf32>, %consumer: memref<16xf32>) {
+ %cst = arith.constant 2.000000e+00 : f32
+ // Source isn't removed.
+ // PRODUCER-CONSUMER: affine.for %{{.*}} = 0 to 32
+ affine.for %arg3 = 0 to 32 {
+ %0 = affine.load %producer[%arg3] : memref<32xf32>
+ %2 = arith.mulf %0, %cst : f32
+ affine.store %2, %producer[%arg3] : memref<32xf32>
+ }
+ affine.for %arg3 = 0 to 16 {
+ %0 = affine.load %producer[%arg3] : memref<32xf32>
+ %2 = arith.addf %0, %cst : f32
+ affine.store %2, %consumer[%arg3] : memref<16xf32>
+ }
+ // Fused nest.
+ // PRODUCER-CONSUMER: affine.for %{{.*}} = 0 to 16
+ // PRODUCER-CONSUMER-NEXT: affine.load %{{.*}}[%{{.*}}] : memref<32xf32>
+ // PRODUCER-CONSUMER-NEXT: arith.mulf
+ // PRODUCER-CONSUMER-NEXT: affine.store %{{.*}}, %{{.*}}[0] : memref<1xf32>
+ // PRODUCER-CONSUMER-NEXT: affine.load %{{.*}}[0] : memref<1xf32>
+ // PRODUCER-CONSUMER-NEXT: arith.addf
+ // PRODUCER-CONSUMER-NEXT: affine.store
+ // PRODUCER-CONSUMER-NEXT: }
+ return
+}