[Mlir-commits] [mlir] [MLIR][Affine] Fix private memref creation bug in affine fusion (PR #126028)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Feb 6 01:17:09 PST 2025

llvmbot wrote:



Author: Uday Bondhugula (bondhugula)


Fix a private memref creation bug in affine fusion that is exposed when the
same memref is both loaded from and stored to in the producer nest. Make the
private memref replacement sound.

Change the affine fusion debug type string to `affine-fusion`, which is more compact.

Full diff: https://github.com/llvm/llvm-project/pull/126028.diff

4 Files Affected:

- (modified) mlir/include/mlir/Dialect/Affine/Analysis/Utils.h (+8) 
- (modified) mlir/lib/Dialect/Affine/Analysis/Utils.cpp (+38) 
- (modified) mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp (+59-12) 
- (modified) mlir/test/Dialect/Affine/loop-fusion-4.mlir (+29) 

diff --git a/mlir/include/mlir/Dialect/Affine/Analysis/Utils.h b/mlir/include/mlir/Dialect/Affine/Analysis/Utils.h
index b1fbf4477428ca2..cff4983af17fc90 100644
--- a/mlir/include/mlir/Dialect/Affine/Analysis/Utils.h
+++ b/mlir/include/mlir/Dialect/Affine/Analysis/Utils.h
@@ -610,6 +610,14 @@ FailureOr<AffineValueMap>
 simplifyConstrainedMinMaxOp(Operation *op,
                             FlatAffineValueConstraints constraints);
+/// Find the innermost common `Block` of `a` and `b` in the affine scope
+/// that `a` and `b` are part of. Return nullptr if they belong to different
+/// affine scopes. Also, return nullptr if they do not have a common `Block`
+/// ancestor (e.g., when they are part of the `then` and `else` regions
+/// of an op that itself starts an affine scope).
+mlir::Block *findInnermostCommonBlockInScope(mlir::Operation *a,
+                                             mlir::Operation *b);
 } // namespace affine
 } // namespace mlir
diff --git a/mlir/lib/Dialect/Affine/Analysis/Utils.cpp b/mlir/lib/Dialect/Affine/Analysis/Utils.cpp
index 9c0b5dbf52d299b..d6c62cdd613643e 100644
--- a/mlir/lib/Dialect/Affine/Analysis/Utils.cpp
+++ b/mlir/lib/Dialect/Affine/Analysis/Utils.cpp
@@ -12,6 +12,7 @@
 #include "mlir/Dialect/Affine/Analysis/Utils.h"
 #include "mlir/Analysis/Presburger/PresburgerRelation.h"
 #include "mlir/Dialect/Affine/Analysis/AffineAnalysis.h"
 #include "mlir/Dialect/Affine/Analysis/LoopAnalysis.h"
@@ -2297,3 +2298,40 @@ FailureOr<AffineValueMap> mlir::affine::simplifyConstrainedMinMaxOp(
   affine::canonicalizeMapAndOperands(&newMap, &newOperands);
   return AffineValueMap(newMap, newOperands);
+Block *mlir::affine::findInnermostCommonBlockInScope(Operation *a,
+                                                     Operation *b) {
+  Region *aScope = mlir::affine::getAffineScope(a);
+  Region *bScope = mlir::affine::getAffineScope(b);
+  if (aScope != bScope)
+    return nullptr;
+  // Get the block ancestry of `op`, outermost first, up to the affine scope.
+  auto getBlockAncestry = [&](Operation *op,
+                              SmallVectorImpl<Block *> &ancestry) {
+    Operation *curOp = op;
+    do {
+      ancestry.push_back(curOp->getBlock());
+      if (curOp->getParentRegion() == aScope)
+        break;
+      curOp = curOp->getParentOp();
+    } while (curOp);
+    assert(curOp && "can't reach root op without passing through affine scope");
+    std::reverse(ancestry.begin(), ancestry.end());
+  };
+  SmallVector<Block *, 4> aAncestors, bAncestors;
+  getBlockAncestry(a, aAncestors);
+  getBlockAncestry(b, bAncestors);
+  assert(!aAncestors.empty() && !bAncestors.empty() &&
+         "at least one Block ancestor expected");
+  Block *innermostCommonBlock = nullptr;
+  for (unsigned a = 0, b = 0, e = aAncestors.size(), f = bAncestors.size();
+       a < e && b < f; ++a, ++b) {
+    if (aAncestors[a] != bAncestors[b])
+      break;
+    innermostCommonBlock = aAncestors[a];
+  }
+  return innermostCommonBlock;
diff --git a/mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp b/mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp
index c22ec213be95c84..0ea27df704d0694 100644
--- a/mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp
+++ b/mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp
@@ -41,7 +41,7 @@ namespace affine {
 } // namespace affine
 } // namespace mlir
-#define DEBUG_TYPE "affine-loop-fusion"
+#define DEBUG_TYPE "affine-fusion"
 using namespace mlir;
 using namespace mlir::affine;
@@ -237,29 +237,71 @@ static void sinkSequentialLoops(MemRefDependenceGraph::Node *node) {
   node->op = newRootForOp;
+/// Get the operation that should act as a dominance filter while replacing
+/// memref uses with a private memref for which `producerStores` and
+/// `sliceInsertionBlock` are provided. This effectively determines in what
+/// part of the IR we should be performing the replacement.
+static Operation *
+getDominanceFilterForPrivateMemRefRepl(Block *sliceInsertionBlock,
+                                       ArrayRef<Operation *> producerStores) {
+  assert(!producerStores.empty() && "expected producer store");
+  // We first find the common block that contains the producer stores and
+  // the slice computation. The first ancestor among the ancestors of the
+  // producer stores in that common block is the dominance filter to use for
+  // replacement.
+  Block *commonBlock = nullptr;
+  // Find the common block of all relevant operations.
+  for (Operation *store : producerStores) {
+    if (!commonBlock)
+      commonBlock = findInnermostCommonBlockInScope(
+          store, &*sliceInsertionBlock->begin());
+    else
+      commonBlock =
+          findInnermostCommonBlockInScope(store, &*commonBlock->begin());
+  }
+  assert(commonBlock &&
+         "common block of producer stores and slice should exist");
+  // Find the first ancestor among the ancestors of `producerStores` in
+  // `commonBlock`.
+  Operation *firstAncestor = nullptr;
+  for (Operation *store : producerStores) {
+    Operation *ancestor = commonBlock->findAncestorOpInBlock(*store);
+    assert(ancestor && "producer store should be contained in common block");
+    firstAncestor = !firstAncestor || ancestor->isBeforeInBlock(firstAncestor)
+                        ? ancestor
+                        : firstAncestor;
+  }
+  return firstAncestor;
 // Creates and returns a private (single-user) memref for fused loop rooted
 // at 'forOp', with (potentially reduced) memref size based on the
 // MemRefRegion written to by 'srcStoreOpInst' at depth 'dstLoopDepth'.
 // TODO: consider refactoring the common code from generateDma and
 // this one.
-static Value createPrivateMemRef(AffineForOp forOp, Operation *srcStoreOpInst,
+static Value createPrivateMemRef(AffineForOp forOp,
+                                 ArrayRef<Operation *> storeOps,
                                  unsigned dstLoopDepth,
                                  std::optional<unsigned> fastMemorySpace,
+                                 Block *sliceInsertionBlock,
                                  uint64_t localBufSizeThreshold) {
-  Operation *forInst = forOp.getOperation();
+  assert(!storeOps.empty() && "no source stores supplied");
+  Operation *srcStoreOp = storeOps[0];
   // Create builder to insert alloc op just before 'forOp'.
-  OpBuilder b(forInst);
+  OpBuilder b(forOp);
   // Builder to create constants at the top level.
-  OpBuilder top(forInst->getParentRegion());
+  OpBuilder top(forOp->getParentRegion());
   // Create new memref type based on slice bounds.
-  auto oldMemRef = cast<AffineWriteOpInterface>(srcStoreOpInst).getMemRef();
+  auto oldMemRef = cast<AffineWriteOpInterface>(srcStoreOp).getMemRef();
   auto oldMemRefType = cast<MemRefType>(oldMemRef.getType());
   unsigned rank = oldMemRefType.getRank();
   // Compute MemRefRegion for 'srcStoreOpInst' at depth 'dstLoopDepth'.
-  MemRefRegion region(srcStoreOpInst->getLoc());
-  bool validRegion = succeeded(region.compute(srcStoreOpInst, dstLoopDepth));
+  MemRefRegion region(srcStoreOp->getLoc());
+  bool validRegion = succeeded(region.compute(srcStoreOp, dstLoopDepth));
   assert(validRegion && "unexpected memref region failure");
   SmallVector<int64_t, 4> newShape;
@@ -332,11 +374,12 @@ static Value createPrivateMemRef(AffineForOp forOp, Operation *srcStoreOpInst,
       AffineMap::get(outerIVs.size() + rank, 0, remapExprs, forOp.getContext());
   // Replace all users of 'oldMemRef' with 'newMemRef'.
+  Operation *domFilter =
+      getDominanceFilterForPrivateMemRefRepl(sliceInsertionBlock, storeOps);
   LogicalResult res =
       replaceAllMemRefUsesWith(oldMemRef, newMemRef, {}, indexRemap,
-                               /*symbolOperands=*/{},
-                               /*domOpFilter=*/&*forOp.getBody()->begin());
+                               /*symbolOperands=*/{}, domFilter);
   assert(succeeded(res) &&
          "replaceAllMemrefUsesWith should always succeed here");
@@ -944,6 +987,10 @@ struct GreedyFusion {
         // Create private memrefs.
         if (!privateMemrefs.empty()) {
+          // Note the block into which fusion was performed. This can be used to
+          // place `alloc`s that create private memrefs.
+          Block *sliceInsertionBlock = bestSlice.insertPoint->getBlock();
           // Gather stores for all the private-to-be memrefs.
           DenseMap<Value, SmallVector<Operation *, 4>> privateMemRefToStores;
           dstAffineForOp.walk([&](AffineWriteOpInterface storeOp) {
@@ -962,8 +1009,8 @@ struct GreedyFusion {
             SmallVector<Operation *, 4> &storesForMemref =
             Value newMemRef = createPrivateMemRef(
-                dstAffineForOp, storesForMemref[0], bestDstLoopDepth,
-                fastMemorySpace, localBufSizeThreshold);
+                dstAffineForOp, storesForMemref, bestDstLoopDepth,
+                fastMemorySpace, sliceInsertionBlock, localBufSizeThreshold);
             // Create new node in dependence graph for 'newMemRef' alloc op.
             unsigned newMemRefNodeId = mdg->addNode(newMemRef.getDefiningOp());
             // Add edge from 'newMemRef' node to dstNode.
diff --git a/mlir/test/Dialect/Affine/loop-fusion-4.mlir b/mlir/test/Dialect/Affine/loop-fusion-4.mlir
index ea144f73bb21c6d..1241a46fb389419 100644
--- a/mlir/test/Dialect/Affine/loop-fusion-4.mlir
+++ b/mlir/test/Dialect/Affine/loop-fusion-4.mlir
@@ -285,3 +285,32 @@ module {
     spirv.ReturnValue %3 : !spirv.array<8192 x f32>
+// -----
+// PRODUCER-CONSUMER-LABEL: func @same_memref_load_store
+func.func @same_memref_load_store(%producer : memref<32xf32>, %consumer: memref<16xf32>){
+  %cst = arith.constant 2.000000e+00 : f32
+  // Source isn't removed.
+  // PRODUCER-CONSUMER: affine.for %{{.*}} = 0 to 32
+  affine.for %arg3 = 0 to 32 {
+    %0 = affine.load %producer[%arg3] : memref<32xf32>
+    %2 = arith.mulf %0, %cst : f32
+    affine.store %2, %producer[%arg3] : memref<32xf32>
+  }
+  affine.for %arg3 = 0 to 16 {
+    %0 = affine.load %producer[%arg3] : memref<32xf32>
+    %2 = arith.addf %0, %cst : f32
+    affine.store %2, %consumer[%arg3] : memref<16xf32>
+  }
+  // Fused nest.
+  // PRODUCER-CONSUMER:      affine.for %{{.*}} = 0 to 16
+  // PRODUCER-CONSUMER-NEXT:   affine.load %{{.*}}[%{{.*}}] : memref<32xf32>
+  // PRODUCER-CONSUMER-NEXT:   arith.mulf
+  // PRODUCER-CONSUMER-NEXT:   affine.store %{{.*}}, %{{.*}}[0] : memref<1xf32>
+  // PRODUCER-CONSUMER-NEXT:   affine.load %{{.*}}[0] : memref<1xf32>
+  // PRODUCER-CONSUMER-NEXT:   arith.addf
+  // PRODUCER-CONSUMER-NEXT:   affine.store
+  return




More information about the Mlir-commits mailing list