[Mlir-commits] [mlir] [MLIR][Affine] Improve sibling fusion - handle memrefs from memref defining nodes (PR #149641)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Sat Jul 19 00:00:34 PDT 2025
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir
Author: Uday Bondhugula (bondhugula)
<details>
<summary>Changes</summary>
Improve sibling fusion: handle memrefs that come from memref-defining nodes (e.g. view-like ops or alloc ops), which were previously not considered.
Remove the unnecessary restriction that limited MDG memref edge iteration to affine.for ops; nodes in the MDG can be other ops as well.
Fixes: https://github.com/llvm/llvm-project/issues/61825
---
Full diff: https://github.com/llvm/llvm-project/pull/149641.diff
3 Files Affected:
- (modified) mlir/lib/Dialect/Affine/Analysis/Utils.cpp (+2-5)
- (modified) mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp (+4-2)
- (modified) mlir/test/Dialect/Affine/loop-fusion-4.mlir (+28)
``````````diff
diff --git a/mlir/lib/Dialect/Affine/Analysis/Utils.cpp b/mlir/lib/Dialect/Affine/Analysis/Utils.cpp
index 4739290bf6e4b..a89c1ae475b96 100644
--- a/mlir/lib/Dialect/Affine/Analysis/Utils.cpp
+++ b/mlir/lib/Dialect/Affine/Analysis/Utils.cpp
@@ -710,7 +710,7 @@ void MemRefDependenceGraph::clearNodeLoadAndStores(unsigned id) {
void MemRefDependenceGraph::forEachMemRefInputEdge(
unsigned id, const std::function<void(Edge)> &callback) {
if (inEdges.count(id) > 0)
- forEachMemRefEdge(inEdges[id], callback);
+ forEachMemRefEdge(inEdges.at(id), callback);
}
// Calls 'callback' for each output edge from node 'id' which carries a
@@ -718,7 +718,7 @@ void MemRefDependenceGraph::forEachMemRefInputEdge(
void MemRefDependenceGraph::forEachMemRefOutputEdge(
unsigned id, const std::function<void(Edge)> &callback) {
if (outEdges.count(id) > 0)
- forEachMemRefEdge(outEdges[id], callback);
+ forEachMemRefEdge(outEdges.at(id), callback);
}
// Calls 'callback' for each edge in 'edges' which carries a memref
@@ -730,9 +730,6 @@ void MemRefDependenceGraph::forEachMemRefEdge(
if (!isa<MemRefType>(edge.value.getType()))
continue;
assert(nodes.count(edge.id) > 0);
- // Skip if 'edge.id' is not a loop nest.
- if (!isa<AffineForOp>(getNode(edge.id)->op))
- continue;
// Visit current input edge 'edge'.
callback(edge);
}
diff --git a/mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp b/mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp
index 95848d0b67547..1d5a665bf6bb1 100644
--- a/mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp
+++ b/mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp
@@ -1473,9 +1473,11 @@ struct GreedyFusion {
SmallVector<MemRefDependenceGraph::Edge, 2> inEdges;
mdg->forEachMemRefInputEdge(
dstNode->id, [&](MemRefDependenceGraph::Edge inEdge) {
- // Add 'inEdge' if it is a read-after-write dependence.
+ // Add 'inEdge' if it is a read-after-write dependence or an edge
+ // from a memref defining op (e.g. view-like op or alloc op).
if (dstNode->getLoadOpCount(inEdge.value) > 0 &&
- mdg->getNode(inEdge.id)->getStoreOpCount(inEdge.value) > 0)
+ (mdg->getNode(inEdge.id)->getStoreOpCount(inEdge.value) > 0 ||
+ inEdge.value.getDefiningOp() == mdg->getNode(inEdge.id)->op))
inEdges.push_back(inEdge);
});
diff --git a/mlir/test/Dialect/Affine/loop-fusion-4.mlir b/mlir/test/Dialect/Affine/loop-fusion-4.mlir
index b059b5a98405d..04c8c3ee809a1 100644
--- a/mlir/test/Dialect/Affine/loop-fusion-4.mlir
+++ b/mlir/test/Dialect/Affine/loop-fusion-4.mlir
@@ -743,3 +743,31 @@ module {
return
}
}
+
+// SIBLING-MAXIMAL-LABEL: memref_cast_reused
+func.func @memref_cast_reused(%arg: memref<*xf32>) {
+ %alloc = memref.cast %arg : memref<*xf32> to memref<10xf32>
+ %alloc_0 = memref.alloc() : memref<10xf32>
+ %alloc_1 = memref.alloc() : memref<10xf32>
+ %cst = arith.constant 0.000000e+00 : f32
+ %cst_2 = arith.constant 1.000000e+00 : f32
+ affine.for %arg0 = 0 to 10 {
+ %0 = affine.load %alloc[%arg0] : memref<10xf32>
+ %1 = arith.addf %0, %cst_2 : f32
+ affine.store %1, %alloc_0[%arg0] : memref<10xf32>
+ }
+ affine.for %arg0 = 0 to 10 {
+ %0 = affine.load %alloc[%arg0] : memref<10xf32>
+ %1 = affine.load %alloc_1[0] : memref<10xf32>
+ %2 = arith.addf %0, %1 : f32
+ affine.store %2, %alloc_1[0] : memref<10xf32>
+ }
+ // SIBLING-MAXIMAL: affine.for %{{.*}} = 0 to 10
+ // SIBLING-MAXIMAL: addf
+ // SIBLING-MAXIMAL-NEXT: affine.store
+ // SIBLING-MAXIMAL-NEXT: affine.load
+ // SIBLING-MAXIMAL-NEXT: affine.load
+ // SIBLING-MAXIMAL-NEXT: addf
+ // SIBLING-MAXIMAL-NEXT: affine.store
+ return
+}
``````````
</details>
https://github.com/llvm/llvm-project/pull/149641
More information about the Mlir-commits
mailing list