[Mlir-commits] [mlir] [MLIR][Affine] Improve sibling fusion - handle memrefs from memref defining nodes (PR #149641)

Uday Bondhugula llvmlistbot at llvm.org
Sat Jul 19 00:00:05 PDT 2025


https://github.com/bondhugula created https://github.com/llvm/llvm-project/pull/149641

Improve sibling fusion: handle memrefs produced by memref-defining nodes (e.g. alloc or view-like ops such as memref.cast), which were previously not considered.

Remove the unnecessary restriction that limited MDG memref edge iteration to affine.for ops; nodes in the MDG can be other ops as well.

Fixes: https://github.com/llvm/llvm-project/issues/61825

>From 441430ea7a57e7108f1a2387a55cd3b39dad55ca Mon Sep 17 00:00:00 2001
From: Uday Bondhugula <uday at polymagelabs.com>
Date: Fri, 18 Jul 2025 06:24:31 +0530
Subject: [PATCH] [MLIR][Affine] Improve sibling fusion - handle memrefs from
 memref defining nodes

Improve sibling fusion: handle memrefs produced by memref-defining nodes
(e.g. alloc or view-like ops), which were previously not considered.

Remove the unnecessary restriction that limited MDG memref edge iteration
to affine.for ops; nodes in the MDG can be other ops as well.

Fixes: https://github.com/llvm/llvm-project/issues/61825
---
 mlir/lib/Dialect/Affine/Analysis/Utils.cpp    |  7 ++---
 .../Dialect/Affine/Transforms/LoopFusion.cpp  |  6 ++--
 mlir/test/Dialect/Affine/loop-fusion-4.mlir   | 28 +++++++++++++++++++
 3 files changed, 34 insertions(+), 7 deletions(-)

diff --git a/mlir/lib/Dialect/Affine/Analysis/Utils.cpp b/mlir/lib/Dialect/Affine/Analysis/Utils.cpp
index 4739290bf6e4b..a89c1ae475b96 100644
--- a/mlir/lib/Dialect/Affine/Analysis/Utils.cpp
+++ b/mlir/lib/Dialect/Affine/Analysis/Utils.cpp
@@ -710,7 +710,7 @@ void MemRefDependenceGraph::clearNodeLoadAndStores(unsigned id) {
 void MemRefDependenceGraph::forEachMemRefInputEdge(
     unsigned id, const std::function<void(Edge)> &callback) {
   if (inEdges.count(id) > 0)
-    forEachMemRefEdge(inEdges[id], callback);
+    forEachMemRefEdge(inEdges.at(id), callback);
 }
 
 // Calls 'callback' for each output edge from node 'id' which carries a
@@ -718,7 +718,7 @@ void MemRefDependenceGraph::forEachMemRefInputEdge(
 void MemRefDependenceGraph::forEachMemRefOutputEdge(
     unsigned id, const std::function<void(Edge)> &callback) {
   if (outEdges.count(id) > 0)
-    forEachMemRefEdge(outEdges[id], callback);
+    forEachMemRefEdge(outEdges.at(id), callback);
 }
 
 // Calls 'callback' for each edge in 'edges' which carries a memref
@@ -730,9 +730,6 @@ void MemRefDependenceGraph::forEachMemRefEdge(
     if (!isa<MemRefType>(edge.value.getType()))
       continue;
     assert(nodes.count(edge.id) > 0);
-    // Skip if 'edge.id' is not a loop nest.
-    if (!isa<AffineForOp>(getNode(edge.id)->op))
-      continue;
     // Visit current input edge 'edge'.
     callback(edge);
   }
diff --git a/mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp b/mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp
index 95848d0b67547..1d5a665bf6bb1 100644
--- a/mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp
+++ b/mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp
@@ -1473,9 +1473,11 @@ struct GreedyFusion {
     SmallVector<MemRefDependenceGraph::Edge, 2> inEdges;
     mdg->forEachMemRefInputEdge(
         dstNode->id, [&](MemRefDependenceGraph::Edge inEdge) {
-          // Add 'inEdge' if it is a read-after-write dependence.
+          // Add 'inEdge' if it is a read-after-write dependence or an edge
+          // from a memref defining op (e.g. view-like op or alloc op).
           if (dstNode->getLoadOpCount(inEdge.value) > 0 &&
-              mdg->getNode(inEdge.id)->getStoreOpCount(inEdge.value) > 0)
+              (mdg->getNode(inEdge.id)->getStoreOpCount(inEdge.value) > 0 ||
+               inEdge.value.getDefiningOp() == mdg->getNode(inEdge.id)->op))
             inEdges.push_back(inEdge);
         });
 
diff --git a/mlir/test/Dialect/Affine/loop-fusion-4.mlir b/mlir/test/Dialect/Affine/loop-fusion-4.mlir
index b059b5a98405d..04c8c3ee809a1 100644
--- a/mlir/test/Dialect/Affine/loop-fusion-4.mlir
+++ b/mlir/test/Dialect/Affine/loop-fusion-4.mlir
@@ -743,3 +743,31 @@ module {
     return
   }
 }
+
+// SIBLING-MAXIMAL-LABEL: memref_cast_reused
+func.func @memref_cast_reused(%arg: memref<*xf32>) {
+  %alloc = memref.cast %arg : memref<*xf32> to memref<10xf32>
+  %alloc_0 = memref.alloc() : memref<10xf32>
+  %alloc_1 = memref.alloc() : memref<10xf32>
+  %cst = arith.constant 0.000000e+00 : f32
+  %cst_2 = arith.constant 1.000000e+00 : f32
+  affine.for %arg0 = 0 to 10 {
+    %0 = affine.load %alloc[%arg0] : memref<10xf32>
+    %1 = arith.addf %0, %cst_2 : f32
+    affine.store %1, %alloc_0[%arg0] : memref<10xf32>
+  }
+  affine.for %arg0 = 0 to 10 {
+    %0 = affine.load %alloc[%arg0] : memref<10xf32>
+    %1 = affine.load %alloc_1[0] : memref<10xf32>
+    %2 = arith.addf %0, %1 : f32
+    affine.store %2, %alloc_1[0] : memref<10xf32>
+  }
+  // SIBLING-MAXIMAL:      affine.for %{{.*}} = 0 to 10
+  // SIBLING-MAXIMAL:        addf
+  // SIBLING-MAXIMAL-NEXT:   affine.store
+  // SIBLING-MAXIMAL-NEXT:   affine.load
+  // SIBLING-MAXIMAL-NEXT:   affine.load
+  // SIBLING-MAXIMAL-NEXT:   addf
+  // SIBLING-MAXIMAL-NEXT:   affine.store
+  return
+}



More information about the Mlir-commits mailing list