[Mlir-commits] [mlir] [MLIR][Affine] Loop fusion in a block containing Linalg op (PR #129136)

Andrey Turetskiy llvmlistbot at llvm.org
Thu Feb 27 15:57:12 PST 2025


https://github.com/aturetsk updated https://github.com/llvm/llvm-project/pull/129136

>From 2336b0d6a0eb7112ae8ce423ab02f8c5ebc6aae6 Mon Sep 17 00:00:00 2001
From: Andrey Turetskiy <turetski at cadence.com>
Date: Wed, 26 Feb 2025 17:51:07 -0800
Subject: [PATCH] [MLIR][Affine] Loop fusion in a block containing Linalg op

Handle region-holding ops that implement the Linalg interface in
MemRefDependenceGraph.
---
 mlir/lib/Dialect/Affine/Analysis/Utils.cpp  |  9 ++++---
 mlir/test/Dialect/Affine/loop-fusion-4.mlir | 30 +++++++++++++++++++++
 2 files changed, 36 insertions(+), 3 deletions(-)

diff --git a/mlir/lib/Dialect/Affine/Analysis/Utils.cpp b/mlir/lib/Dialect/Affine/Analysis/Utils.cpp
index ba6f045cff408..54a17e93fc996 100644
--- a/mlir/lib/Dialect/Affine/Analysis/Utils.cpp
+++ b/mlir/lib/Dialect/Affine/Analysis/Utils.cpp
@@ -19,6 +19,7 @@
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include "mlir/Dialect/Affine/IR/AffineValueMap.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
 #include "mlir/Dialect/Utils/StaticValueUtils.h"
 #include "mlir/IR/IntegerSet.h"
 #include "mlir/Interfaces/CallInterfaces.h"
@@ -252,6 +253,9 @@ bool MemRefDependenceGraph::init() {
   // Create graph nodes.
   DenseMap<Operation *, unsigned> forToNodeMap;
   for (Operation &op : block) {
+    bool hasUnsupportedRegion =
+        op.getNumRegions() != 0 &&
+        !isa<RegionBranchOpInterface, linalg::LinalgOp>(op);
     if (auto forOp = dyn_cast<AffineForOp>(op)) {
       Node *node = addNodeToMDG(&op, *this, memrefAccesses);
       if (!node)
@@ -277,8 +281,7 @@ bool MemRefDependenceGraph::init() {
       Node *node = addNodeToMDG(&op, *this, memrefAccesses);
       if (!node)
         return false;
-    } else if (!isMemoryEffectFree(&op) &&
-               (op.getNumRegions() == 0 || isa<RegionBranchOpInterface>(op))) {
+    } else if (!isMemoryEffectFree(&op) && !hasUnsupportedRegion) {
       // Create graph node for top-level op unless it is known to be
       // memory-effect free. This covers all unknown/unregistered ops,
       // non-affine ops with memory effects, and region-holding ops with a
@@ -287,7 +290,7 @@ bool MemRefDependenceGraph::init() {
       Node *node = addNodeToMDG(&op, *this, memrefAccesses);
       if (!node)
         return false;
-    } else if (op.getNumRegions() != 0 && !isa<RegionBranchOpInterface>(op)) {
+    } else if (hasUnsupportedRegion) {
       // Return false if non-handled/unknown region-holding ops are found. We
       // won't know what such ops do or what its regions mean; for e.g., it may
       // not be an imperative op.
diff --git a/mlir/test/Dialect/Affine/loop-fusion-4.mlir b/mlir/test/Dialect/Affine/loop-fusion-4.mlir
index b5b951ad5eb0e..69c7c250404e1 100644
--- a/mlir/test/Dialect/Affine/loop-fusion-4.mlir
+++ b/mlir/test/Dialect/Affine/loop-fusion-4.mlir
@@ -548,6 +548,36 @@ func.func @sibling_reduction(%input : memref<10xf32>, %output : memref<10xf32>,
 
 // -----
 
+// Check that the presence of a Linalg op in a block does not prevent
+// fusion from happening in that block.
+
+// ALL-LABEL: func @fusion_in_block_containing_linalg
+func.func @fusion_in_block_containing_linalg(%arg0: memref<5xi8>, %arg1: memref<5xi8>) {
+  %c15_i8 = arith.constant 15 : i8
+  %alloc = memref.alloc() : memref<5xi8>
+  affine.for %arg3 = 0 to 5 {
+    affine.store %c15_i8, %alloc[%arg3] : memref<5xi8>
+  }
+  affine.for %arg3 = 0 to 5 {
+    %0 = affine.load %alloc[%arg3] : memref<5xi8>
+    %1 = affine.load %arg0[%arg3] : memref<5xi8>
+    %2 = arith.muli %0, %1 : i8
+    affine.store %2, %alloc[%arg3] : memref<5xi8>
+  }
+  // ALL:       affine.for
+  // ALL-NEXT:    affine.store
+  // ALL-NEXT:    affine.load
+  // ALL-NEXT:    affine.load
+  // ALL-NEXT:    arith.muli
+  // ALL-NEXT:    affine.store
+  // ALL-NEXT:  }
+  linalg.elemwise_binary ins(%alloc, %alloc: memref<5xi8>, memref<5xi8>) outs(%arg1: memref<5xi8>)
+  // ALL-NEXT:  linalg.elemwise_binary
+  return
+}
+
+// -----
+
 // From  https://github.com/llvm/llvm-project/issues/54541
 
 #map = affine_map<(d0) -> (d0 mod 65536)>



More information about the Mlir-commits mailing list