[Mlir-commits] [mlir] 3e4be55 - [MLIR][Affine] Improve sibling fusion - handle memrefs from memref defining nodes (#149641)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Jul 23 21:54:59 PDT 2025


Author: Uday Bondhugula
Date: 2025-07-24T10:24:56+05:30
New Revision: 3e4be55c6b942fdcf1732b1a537e2477567ece54

URL: https://github.com/llvm/llvm-project/commit/3e4be55c6b942fdcf1732b1a537e2477567ece54
DIFF: https://github.com/llvm/llvm-project/commit/3e4be55c6b942fdcf1732b1a537e2477567ece54.diff

LOG: [MLIR][Affine] Improve sibling fusion - handle memrefs from memref defining nodes (#149641)

Improve sibling fusion: handle memrefs that come from memref-defining
nodes (e.g. view-like or alloc ops), which were previously not considered.

Remove the unnecessary check that limited MDG memref edge iteration to
edges from affine.for ops; nodes in the MDG can be other ops as well.

Fixes: https://github.com/llvm/llvm-project/issues/61825

Added: 
    

Modified: 
    mlir/lib/Dialect/Affine/Analysis/Utils.cpp
    mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp
    mlir/test/Dialect/Affine/loop-fusion-4.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Affine/Analysis/Utils.cpp b/mlir/lib/Dialect/Affine/Analysis/Utils.cpp
index 4739290bf6e4b..a89c1ae475b96 100644
--- a/mlir/lib/Dialect/Affine/Analysis/Utils.cpp
+++ b/mlir/lib/Dialect/Affine/Analysis/Utils.cpp
@@ -710,7 +710,7 @@ void MemRefDependenceGraph::clearNodeLoadAndStores(unsigned id) {
 void MemRefDependenceGraph::forEachMemRefInputEdge(
     unsigned id, const std::function<void(Edge)> &callback) {
   if (inEdges.count(id) > 0)
-    forEachMemRefEdge(inEdges[id], callback);
+    forEachMemRefEdge(inEdges.at(id), callback);
 }
 
 // Calls 'callback' for each output edge from node 'id' which carries a
@@ -718,7 +718,7 @@ void MemRefDependenceGraph::forEachMemRefInputEdge(
 void MemRefDependenceGraph::forEachMemRefOutputEdge(
     unsigned id, const std::function<void(Edge)> &callback) {
   if (outEdges.count(id) > 0)
-    forEachMemRefEdge(outEdges[id], callback);
+    forEachMemRefEdge(outEdges.at(id), callback);
 }
 
 // Calls 'callback' for each edge in 'edges' which carries a memref
@@ -730,9 +730,6 @@ void MemRefDependenceGraph::forEachMemRefEdge(
     if (!isa<MemRefType>(edge.value.getType()))
       continue;
     assert(nodes.count(edge.id) > 0);
-    // Skip if 'edge.id' is not a loop nest.
-    if (!isa<AffineForOp>(getNode(edge.id)->op))
-      continue;
     // Visit current input edge 'edge'.
     callback(edge);
   }

diff  --git a/mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp b/mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp
index 95848d0b67547..1d5a665bf6bb1 100644
--- a/mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp
+++ b/mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp
@@ -1473,9 +1473,11 @@ struct GreedyFusion {
     SmallVector<MemRefDependenceGraph::Edge, 2> inEdges;
     mdg->forEachMemRefInputEdge(
         dstNode->id, [&](MemRefDependenceGraph::Edge inEdge) {
-          // Add 'inEdge' if it is a read-after-write dependence.
+          // Add 'inEdge' if it is a read-after-write dependence or an edge
+          // from a memref defining op (e.g. view-like op or alloc op).
           if (dstNode->getLoadOpCount(inEdge.value) > 0 &&
-              mdg->getNode(inEdge.id)->getStoreOpCount(inEdge.value) > 0)
+              (mdg->getNode(inEdge.id)->getStoreOpCount(inEdge.value) > 0 ||
+               inEdge.value.getDefiningOp() == mdg->getNode(inEdge.id)->op))
             inEdges.push_back(inEdge);
         });
 

diff  --git a/mlir/test/Dialect/Affine/loop-fusion-4.mlir b/mlir/test/Dialect/Affine/loop-fusion-4.mlir
index b059b5a98405d..04c8c3ee809a1 100644
--- a/mlir/test/Dialect/Affine/loop-fusion-4.mlir
+++ b/mlir/test/Dialect/Affine/loop-fusion-4.mlir
@@ -743,3 +743,31 @@ module {
     return
   }
 }
+
+// SIBLING-MAXIMAL-LABEL: memref_cast_reused
+func.func @memref_cast_reused(%arg: memref<*xf32>) {
+  %alloc = memref.cast %arg : memref<*xf32> to memref<10xf32>
+  %alloc_0 = memref.alloc() : memref<10xf32>
+  %alloc_1 = memref.alloc() : memref<10xf32>
+  %cst = arith.constant 0.000000e+00 : f32
+  %cst_2 = arith.constant 1.000000e+00 : f32
+  affine.for %arg0 = 0 to 10 {
+    %0 = affine.load %alloc[%arg0] : memref<10xf32>
+    %1 = arith.addf %0, %cst_2 : f32
+    affine.store %1, %alloc_0[%arg0] : memref<10xf32>
+  }
+  affine.for %arg0 = 0 to 10 {
+    %0 = affine.load %alloc[%arg0] : memref<10xf32>
+    %1 = affine.load %alloc_1[0] : memref<10xf32>
+    %2 = arith.addf %0, %1 : f32
+    affine.store %2, %alloc_1[0] : memref<10xf32>
+  }
+  // SIBLING-MAXIMAL:      affine.for %{{.*}} = 0 to 10
+  // SIBLING-MAXIMAL:        addf
+  // SIBLING-MAXIMAL-NEXT:   affine.store
+  // SIBLING-MAXIMAL-NEXT:   affine.load
+  // SIBLING-MAXIMAL-NEXT:   affine.load
+  // SIBLING-MAXIMAL-NEXT:   addf
+  // SIBLING-MAXIMAL-NEXT:   affine.store
+  return
+}


        


More information about the Mlir-commits mailing list