[Mlir-commits] [mlir] b850ce4 - [MLIR][Affine] Fix private memref creation bug in affine fusion (#126028)
llvmlistbot at llvm.org
Fri Feb 7 19:05:15 PST 2025
Author: Uday Bondhugula
Date: 2025-02-08T08:35:10+05:30
New Revision: b850ce41db1e90cb2573ab5880da1d05de7828fd
URL: https://github.com/llvm/llvm-project/commit/b850ce41db1e90cb2573ab5880da1d05de7828fd
DIFF: https://github.com/llvm/llvm-project/commit/b850ce41db1e90cb2573ab5880da1d05de7828fd.diff
LOG: [MLIR][Affine] Fix private memref creation bug in affine fusion (#126028)
Fix a private memref creation bug in affine fusion that was exposed when the
same memref is both loaded from and stored to in the producer nest. Make the
private memref replacement sound.
Change the affine fusion debug string to affine-fusion, which is more compact.
Fixes: https://github.com/llvm/llvm-project/issues/48703
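As a rough sketch of the pattern involved (mirroring the new test case added below; the function and value names here are illustrative), the producer nest both loads from and stores to the same memref that the consumer nest later reads:

func.func @producer_consumer_sketch(%m : memref<32xf32>, %out : memref<16xf32>) {
  %cst = arith.constant 2.0 : f32
  // Producer nest: loads from and stores to the same memref %m.
  affine.for %i = 0 to 32 {
    %v = affine.load %m[%i] : memref<32xf32>
    %p = arith.mulf %v, %cst : f32
    affine.store %p, %m[%i] : memref<32xf32>
  }
  // Consumer nest: reads %m again.
  affine.for %i = 0 to 16 {
    %v = affine.load %m[%i] : memref<32xf32>
    %s = arith.addf %v, %cst : f32
    affine.store %s, %out[%i] : memref<16xf32>
  }
  return
}

When the producer is fused into the consumer and a private memref is created for %m, uses of %m may only be replaced in the part of the IR dominated by the producer stores, which is what the dominance-filter change below ensures.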
Added:
Modified:
mlir/include/mlir/Dialect/Affine/Analysis/Utils.h
mlir/lib/Dialect/Affine/Analysis/Utils.cpp
mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp
mlir/test/Dialect/Affine/loop-fusion-4.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Affine/Analysis/Utils.h b/mlir/include/mlir/Dialect/Affine/Analysis/Utils.h
index b1fbf4477428ca2..7164ade6ea53a60 100644
--- a/mlir/include/mlir/Dialect/Affine/Analysis/Utils.h
+++ b/mlir/include/mlir/Dialect/Affine/Analysis/Utils.h
@@ -610,6 +610,14 @@ FailureOr<AffineValueMap>
simplifyConstrainedMinMaxOp(Operation *op,
FlatAffineValueConstraints constraints);
+/// Find the innermost common `Block` of `a` and `b` in the affine scope
+/// that `a` and `b` are part of. Return nullptr if they belong to different
+/// affine scopes. Also, return nullptr if they do not have a common `Block`
+/// ancestor (e.g., when they are part of the `then` and `else` regions
+/// of an op that itself starts an affine scope).
+mlir::Block *findInnermostCommonBlockInScope(mlir::Operation *a,
+ mlir::Operation *b);
+
} // namespace affine
} // namespace mlir
diff --git a/mlir/lib/Dialect/Affine/Analysis/Utils.cpp b/mlir/lib/Dialect/Affine/Analysis/Utils.cpp
index 9c0b5dbf52d299b..10de0d04cbea640 100644
--- a/mlir/lib/Dialect/Affine/Analysis/Utils.cpp
+++ b/mlir/lib/Dialect/Affine/Analysis/Utils.cpp
@@ -12,6 +12,7 @@
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Affine/Analysis/Utils.h"
+
#include "mlir/Analysis/Presburger/PresburgerRelation.h"
#include "mlir/Dialect/Affine/Analysis/AffineAnalysis.h"
#include "mlir/Dialect/Affine/Analysis/LoopAnalysis.h"
@@ -2297,3 +2298,41 @@ FailureOr<AffineValueMap> mlir::affine::simplifyConstrainedMinMaxOp(
affine::canonicalizeMapAndOperands(&newMap, &newOperands);
return AffineValueMap(newMap, newOperands);
}
+
+Block *mlir::affine::findInnermostCommonBlockInScope(Operation *a,
+ Operation *b) {
+ Region *aScope = mlir::affine::getAffineScope(a);
+ Region *bScope = mlir::affine::getAffineScope(b);
+ if (aScope != bScope)
+ return nullptr;
+
+ // Get the block ancestry of `op` while stopping at the affine scope `aScope`
+ // and store them in `ancestry`.
+ auto getBlockAncestry = [&](Operation *op,
+ SmallVectorImpl<Block *> &ancestry) {
+ Operation *curOp = op;
+ do {
+ ancestry.push_back(curOp->getBlock());
+ if (curOp->getParentRegion() == aScope)
+ break;
+ curOp = curOp->getParentOp();
+ } while (curOp);
+ assert(curOp && "can't reach root op without passing through affine scope");
+ std::reverse(ancestry.begin(), ancestry.end());
+ };
+
+ SmallVector<Block *, 4> aAncestors, bAncestors;
+ getBlockAncestry(a, aAncestors);
+ getBlockAncestry(b, bAncestors);
+ assert(!aAncestors.empty() && !bAncestors.empty() &&
+ "at least one Block ancestor expected");
+
+ Block *innermostCommonBlock = nullptr;
+ for (unsigned a = 0, b = 0, e = aAncestors.size(), f = bAncestors.size();
+ a < e && b < f; ++a, ++b) {
+ if (aAncestors[a] != bAncestors[b])
+ break;
+ innermostCommonBlock = aAncestors[a];
+ }
+ return innermostCommonBlock;
+}
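As an informal illustration of what the new helper computes (this IR is not part of the patch; names are made up), the innermost common block of the two affine.store ops below is the body block of the outer affine.for, since that is the deepest block containing an ancestor of each op; two ops in different affine scopes would instead yield nullptr:

func.func @common_block_sketch(%m : memref<8xf32>) {
  %c = arith.constant 1.0 : f32
  affine.for %i = 0 to 8 {
    // First store: its enclosing block is the outer loop's body.
    affine.store %c, %m[%i] : memref<8xf32>
    affine.for %j = 0 to 8 {
      // Second store: one block deeper; its ancestor in the outer body
      // block is the enclosing inner affine.for.
      affine.store %c, %m[%j] : memref<8xf32>
    }
  }
  return
}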
diff --git a/mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp b/mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp
index c22ec213be95c84..fe6cf0f434cb7eb 100644
--- a/mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp
+++ b/mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp
@@ -41,7 +41,7 @@ namespace affine {
} // namespace affine
} // namespace mlir
-#define DEBUG_TYPE "affine-loop-fusion"
+#define DEBUG_TYPE "affine-fusion"
using namespace mlir;
using namespace mlir::affine;
@@ -237,29 +237,67 @@ static void sinkSequentialLoops(MemRefDependenceGraph::Node *node) {
node->op = newRootForOp;
}
-// Creates and returns a private (single-user) memref for fused loop rooted
-// at 'forOp', with (potentially reduced) memref size based on the
-// MemRefRegion written to by 'srcStoreOpInst' at depth 'dstLoopDepth'.
-// TODO: consider refactoring the common code from generateDma and
-// this one.
-static Value createPrivateMemRef(AffineForOp forOp, Operation *srcStoreOpInst,
+/// Get the operation that should act as a dominance filter while replacing
+/// memref uses with a private memref for which `producerStores` and
+/// `sliceInsertionBlock` are provided. This effectively determines in what
+/// part of the IR we should be performing the replacement.
+static Operation *
+getDominanceFilterForPrivateMemRefRepl(Block *sliceInsertionBlock,
+ ArrayRef<Operation *> producerStores) {
+ assert(!producerStores.empty() && "expected producer store");
+
+ // We first find the common block that contains the producer stores and
+ // the slice computation. The first ancestor among the ancestors of the
+ // producer stores in that common block is the dominance filter to use for
+ // replacement.
+ Block *commonBlock = nullptr;
+ // Find the common block of all relevant operations.
+ for (Operation *store : producerStores) {
+ Operation *otherOp =
+ !commonBlock ? &*sliceInsertionBlock->begin() : &*commonBlock->begin();
+ commonBlock = findInnermostCommonBlockInScope(store, otherOp);
+ }
+ assert(commonBlock &&
+ "common block of producer stores and slice should exist");
+
+ // Find the first ancestor among the ancestors of `producerStores` in
+ // `commonBlock`.
+ Operation *firstAncestor = nullptr;
+ for (Operation *store : producerStores) {
+ Operation *ancestor = commonBlock->findAncestorOpInBlock(*store);
+ assert(ancestor && "producer store should be contained in common block");
+ firstAncestor = !firstAncestor || ancestor->isBeforeInBlock(firstAncestor)
+ ? ancestor
+ : firstAncestor;
+ }
+ return firstAncestor;
+}
+
+// Creates and returns a private (single-user) memref for fused loop rooted at
+// 'forOp', with (potentially reduced) memref size based on the memref region
+// written to by `storeOps` at depth 'dstLoopDepth'. 'sliceInsertionBlock'
+// specifies the block in which the slice was/will be inserted.
+static Value createPrivateMemRef(AffineForOp forOp,
+ ArrayRef<Operation *> storeOps,
unsigned dstLoopDepth,
std::optional<unsigned> fastMemorySpace,
+ Block *sliceInsertionBlock,
uint64_t localBufSizeThreshold) {
- Operation *forInst = forOp.getOperation();
+ assert(!storeOps.empty() && "no source stores supplied");
+ Operation *srcStoreOp = storeOps[0];
// Create builder to insert alloc op just before 'forOp'.
- OpBuilder b(forInst);
+ OpBuilder b(forOp);
// Builder to create constants at the top level.
- OpBuilder top(forInst->getParentRegion());
+ OpBuilder top(forOp->getParentRegion());
// Create new memref type based on slice bounds.
- auto oldMemRef = cast<AffineWriteOpInterface>(srcStoreOpInst).getMemRef();
+ auto oldMemRef = cast<AffineWriteOpInterface>(srcStoreOp).getMemRef();
auto oldMemRefType = cast<MemRefType>(oldMemRef.getType());
unsigned rank = oldMemRefType.getRank();
// Compute MemRefRegion for 'srcStoreOpInst' at depth 'dstLoopDepth'.
- MemRefRegion region(srcStoreOpInst->getLoc());
- bool validRegion = succeeded(region.compute(srcStoreOpInst, dstLoopDepth));
+ MemRefRegion region(srcStoreOp->getLoc());
+ bool validRegion = succeeded(region.compute(srcStoreOp, dstLoopDepth));
(void)validRegion;
assert(validRegion && "unexpected memref region failure");
SmallVector<int64_t, 4> newShape;
@@ -332,11 +370,12 @@ static Value createPrivateMemRef(AffineForOp forOp, Operation *srcStoreOpInst,
AffineMap::get(outerIVs.size() + rank, 0, remapExprs, forOp.getContext());
// Replace all users of 'oldMemRef' with 'newMemRef'.
- LogicalResult res =
- replaceAllMemRefUsesWith(oldMemRef, newMemRef, {}, indexRemap,
- /*extraOperands=*/outerIVs,
- /*symbolOperands=*/{},
- /*domOpFilter=*/&*forOp.getBody()->begin());
+ Operation *domFilter =
+ getDominanceFilterForPrivateMemRefRepl(sliceInsertionBlock, storeOps);
+ LogicalResult res = replaceAllMemRefUsesWith(
+ oldMemRef, newMemRef, /*extraIndices=*/{}, indexRemap,
+ /*extraOperands=*/outerIVs,
+ /*symbolOperands=*/{}, domFilter);
assert(succeeded(res) &&
"replaceAllMemrefUsesWith should always succeed here");
(void)res;
@@ -944,6 +983,10 @@ struct GreedyFusion {
// Create private memrefs.
if (!privateMemrefs.empty()) {
+ // Note the block into which fusion was performed. This can be used to
+ // place `alloc`s that create private memrefs.
+ Block *sliceInsertionBlock = bestSlice.insertPoint->getBlock();
+
// Gather stores for all the private-to-be memrefs.
DenseMap<Value, SmallVector<Operation *, 4>> privateMemRefToStores;
dstAffineForOp.walk([&](AffineWriteOpInterface storeOp) {
@@ -962,8 +1005,8 @@ struct GreedyFusion {
SmallVector<Operation *, 4> &storesForMemref =
memrefToStoresPair.second;
Value newMemRef = createPrivateMemRef(
- dstAffineForOp, storesForMemref[0], bestDstLoopDepth,
- fastMemorySpace, localBufSizeThreshold);
+ dstAffineForOp, storesForMemref, bestDstLoopDepth,
+ fastMemorySpace, sliceInsertionBlock, localBufSizeThreshold);
// Create new node in dependence graph for 'newMemRef' alloc op.
unsigned newMemRefNodeId = mdg->addNode(newMemRef.getDefiningOp());
// Add edge from 'newMemRef' node to dstNode.
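For the first test case below, the fused nest with the private memref is expected to look roughly like the following hand-written sketch (value names are illustrative; the original source nest over 0 to 32 is kept in the output but omitted here):

func.func @fused_sketch(%producer : memref<32xf32>, %consumer : memref<16xf32>) {
  %cst = arith.constant 2.0 : f32
  // Single-element private buffer allocated just before the fused loop.
  %priv = memref.alloc() : memref<1xf32>
  affine.for %i = 0 to 16 {
    %v = affine.load %producer[%i] : memref<32xf32>
    %p = arith.mulf %v, %cst : f32
    // The producer store inside the fused nest targets the private buffer,
    // and the consumer load reads it back.
    affine.store %p, %priv[0] : memref<1xf32>
    %r = affine.load %priv[0] : memref<1xf32>
    %s = arith.addf %r, %cst : f32
    affine.store %s, %consumer[%i] : memref<16xf32>
  }
  return
}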
diff --git a/mlir/test/Dialect/Affine/loop-fusion-4.mlir b/mlir/test/Dialect/Affine/loop-fusion-4.mlir
index ea144f73bb21c6d..2830235431c7646 100644
--- a/mlir/test/Dialect/Affine/loop-fusion-4.mlir
+++ b/mlir/test/Dialect/Affine/loop-fusion-4.mlir
@@ -285,3 +285,63 @@ module {
spirv.ReturnValue %3 : !spirv.array<8192 x f32>
}
}
+
+// -----
+
+// PRODUCER-CONSUMER-LABEL: func @same_memref_load_store
+func.func @same_memref_load_store(%producer : memref<32xf32>, %consumer: memref<16xf32>){
+ %cst = arith.constant 2.000000e+00 : f32
+ // Source isn't removed.
+ // PRODUCER-CONSUMER: affine.for %{{.*}} = 0 to 32
+ affine.for %arg3 = 0 to 32 {
+ %0 = affine.load %producer[%arg3] : memref<32xf32>
+ %2 = arith.mulf %0, %cst : f32
+ affine.store %2, %producer[%arg3] : memref<32xf32>
+ }
+ affine.for %arg3 = 0 to 16 {
+ %0 = affine.load %producer[%arg3] : memref<32xf32>
+ %2 = arith.addf %0, %cst : f32
+ affine.store %2, %consumer[%arg3] : memref<16xf32>
+ }
+ // Fused nest.
+ // PRODUCER-CONSUMER: affine.for %{{.*}} = 0 to 16
+ // PRODUCER-CONSUMER-NEXT: affine.load %{{.*}}[%{{.*}}] : memref<32xf32>
+ // PRODUCER-CONSUMER-NEXT: arith.mulf
+ // PRODUCER-CONSUMER-NEXT: affine.store %{{.*}}, %{{.*}}[0] : memref<1xf32>
+ // PRODUCER-CONSUMER-NEXT: affine.load %{{.*}}[0] : memref<1xf32>
+ // PRODUCER-CONSUMER-NEXT: arith.addf
+ // PRODUCER-CONSUMER-NEXT: affine.store
+ // PRODUCER-CONSUMER-NEXT: }
+ return
+}
+
+// PRODUCER-CONSUMER-LABEL: func @same_memref_load_multiple_stores
+func.func @same_memref_load_multiple_stores(%producer : memref<32xf32>, %producer_2 : memref<32xf32>, %consumer: memref<16xf32>){
+ %cst = arith.constant 2.000000e+00 : f32
+ // Source isn't removed.
+ // PRODUCER-CONSUMER: affine.for %{{.*}} = 0 to 32
+ affine.for %arg3 = 0 to 32 {
+ %0 = affine.load %producer[%arg3] : memref<32xf32>
+ %2 = arith.mulf %0, %cst : f32
+ affine.store %2, %producer[%arg3] : memref<32xf32>
+ affine.store %2, %producer_2[%arg3] : memref<32xf32>
+ }
+ affine.for %arg3 = 0 to 16 {
+ %0 = affine.load %producer[%arg3] : memref<32xf32>
+ %1 = affine.load %producer_2[%arg3] : memref<32xf32>
+ %2 = arith.addf %0, %1 : f32
+ affine.store %2, %consumer[%arg3] : memref<16xf32>
+ }
+ // Fused nest.
+ // PRODUCER-CONSUMER: affine.for %{{.*}} = 0 to 16
+ // PRODUCER-CONSUMER-NEXT: affine.load %{{.*}}[%{{.*}}] : memref<32xf32>
+ // PRODUCER-CONSUMER-NEXT: arith.mulf
+ // PRODUCER-CONSUMER-NEXT: affine.store %{{.*}}, %{{.*}}[0] : memref<1xf32>
+ // PRODUCER-CONSUMER-NEXT: affine.store %{{.*}}, %{{.*}}[0] : memref<1xf32>
+ // PRODUCER-CONSUMER-NEXT: affine.load %{{.*}}[0] : memref<1xf32>
+ // PRODUCER-CONSUMER-NEXT: affine.load %{{.*}}[0] : memref<1xf32>
+ // PRODUCER-CONSUMER-NEXT: arith.addf
+ // PRODUCER-CONSUMER-NEXT: affine.store
+ // PRODUCER-CONSUMER-NEXT: }
+ return
+}