[Mlir-commits] [mlir] [MLIR][Linalg] Use Top-Down traversal to safely optimize multi-use producer fusion (PR #172216)

Miloš Poletanović llvmlistbot at llvm.org
Thu Dec 18 15:49:01 PST 2025


https://github.com/milos1397 updated https://github.com/llvm/llvm-project/pull/172216

>From e605dabf6d51a65cad92224b04a64420df4c0211 Mon Sep 17 00:00:00 2001
From: Milos Poletanovic <mpoletanovic at syrmia.com>
Date: Sun, 14 Dec 2025 16:15:20 +0100
Subject: [PATCH 1/3] [MLIR][Linalg] Use Top-Down traversal to safely optimize
 multi-use producer fusion

Switches the greedy rewrite driver used for the multi-use producer fusion
test pattern to top-down (pre-order) traversal, so that producers are
visited before their consumers.
---
 .../Linalg/fusion-multiuse-producer.mlir      | 71 +++++++++++++++++++
 .../Linalg/TestLinalgElementwiseFusion.cpp    |  3 +-
 2 files changed, 73 insertions(+), 1 deletion(-)

diff --git a/mlir/test/Dialect/Linalg/fusion-multiuse-producer.mlir b/mlir/test/Dialect/Linalg/fusion-multiuse-producer.mlir
index 7871ae08fd54a..96845448dd1c2 100644
--- a/mlir/test/Dialect/Linalg/fusion-multiuse-producer.mlir
+++ b/mlir/test/Dialect/Linalg/fusion-multiuse-producer.mlir
@@ -32,3 +32,74 @@ func.func @multi_use_producer(%arg0 : tensor<?x?xf32>, %arg1 : tensor<?x?xf32>,
 // CHECK-SAME:     %[[ARG4:[a-zA-Z0-9]+]]: tensor<?x?xf32>)
 //      CHECK:   %[[RESULT:.+]]:3 = linalg.generic
 //      CHECK:   return %[[RESULT]]#0, %[[RESULT]]#1, %[[RESULT]]#2
+
+func.func @multi_use_producer_2(%arg0: tensor<1x32x32x8xf32>) -> tensor<1x32x32x8xf32> attributes {llvm.emit_c_interface} {
+  %0 = llvm.mlir.constant(0.000000e+00 : f32) : f32
+  %1 = llvm.mlir.constant(31 : index) : i64
+  %2 = tensor.empty() : tensor<1x32x32x8xf32>
+  %3 = tensor.empty() : tensor<1x32x32x8xindex>
+  %4:2 = linalg.generic {
+    indexing_maps = [
+      affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>,
+      affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>,
+      affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+    ], 
+    iterator_types = ["parallel", "parallel", "parallel", "parallel"]
+  } 
+  ins(%arg0 : tensor<1x32x32x8xf32>) 
+  outs(%2, %3 : tensor<1x32x32x8xf32>, tensor<1x32x32x8xindex>) {
+    ^bb0(%in: f32, %out: f32, %out_0: index):
+      %9 = linalg.index 1 : index
+      linalg.yield %0, %9 : f32, index
+  } -> (tensor<1x32x32x8xf32>, tensor<1x32x32x8xindex>)
+
+  %5 = tensor.empty() : tensor<1x32x32x8xi64>
+  %6:2 = linalg.generic {
+    indexing_maps = [
+      affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>,
+      affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>,
+      affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>,
+      affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+    ], 
+    iterator_types = ["parallel", "parallel", "parallel", "parallel"]
+  } 
+  ins(%arg0, %4#1 : tensor<1x32x32x8xf32>, tensor<1x32x32x8xindex>) 
+  outs(%2, %5 : tensor<1x32x32x8xf32>, tensor<1x32x32x8xi64>) {
+    ^bb0(%in: f32, %in_0: index, %out: f32, %out_1: i64):
+      %9 = builtin.unrealized_conversion_cast %in_0 : index to i64
+      linalg.yield %0, %9 : f32, i64
+  } -> (tensor<1x32x32x8xf32>, tensor<1x32x32x8xi64>)
+
+  %7 = tensor.empty() : tensor<1x32x32x8xi64>
+  %8:2 = linalg.generic {
+    indexing_maps = [
+      affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>,
+      affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>,
+      affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>,
+      affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>,
+      affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+    ], 
+    iterator_types = ["parallel", "parallel", "parallel", "parallel"]
+  } 
+  ins(%arg0, %4#1, %6#1 : tensor<1x32x32x8xf32>, tensor<1x32x32x8xindex>, tensor<1x32x32x8xi64>) 
+  outs(%2, %7 : tensor<1x32x32x8xf32>, tensor<1x32x32x8xi64>) {
+    ^bb0(%in: f32, %in_0: index, %in_1: i64, %out: f32, %out_2: i64):
+      %9 = llvm.sub %1, %in_1 : i64
+      linalg.yield %0, %9 : f32, i64
+  } -> (tensor<1x32x32x8xf32>, tensor<1x32x32x8xi64>)
+  return %8#0 : tensor<1x32x32x8xf32>
+}
+// CHECK-LABEL: func @multi_use_producer_2(
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<1x32x32x8xf32>)
+// CHECK-SAME: -> tensor<1x32x32x8xf32>
+// CHECK: %[[C31:.+]] = llvm.mlir.constant(31 : index) : i64
+// CHECK: %[[R0:.+]]:2 = linalg.generic {
+// CHECK-SAME: ins(%[[ARG0]], %[[ARG0]], %[[ARG0]], %[[ARG0]], %[[ARG0]] : tensor<1x32x32x8xf32>, tensor<1x32x32x8xf32>, tensor<1x32x32x8xf32>, tensor<1x32x32x8xf32>, tensor<1x32x32x8xf32>)
+// CHECK-SAME: outs(%[[INIT:.+]], %[[INIT_1:.+]] : tensor<1x32x32x8xf32>, tensor<1x32x32x8xi64>)
+// CHECK: ^bb0(%[[IN:.+]]: f32, %[[IN_1:.+]]: f32, %[[IN_2:.+]]: f32, %[[IN_3:.+]]: f32, %[[IN_4:.+]]: f32, %[[OUT:.+]]: f32, %[[OUT_I:.+]]: i64):
+// CHECK: %[[IDX_9:.+]] = linalg.index 1 : index
+// CHECK: %[[C_9:.+]] = builtin.unrealized_conversion_cast %[[IDX_9]] : index to i64
+// CHECK: %[[C_SUB:.+]] = llvm.sub %[[C31]], %[[C_9]] : i64
+// CHECK: linalg.yield %[[C0:.+]], %[[C_SUB]] : f32, i64
+// CHECK: } -> (tensor<1x32x32x8xf32>, tensor<1x32x32x8xi64>)
+// CHECK: return %[[R0]]#0 : tensor<1x32x32x8xf32>
\ No newline at end of file
diff --git a/mlir/test/lib/Dialect/Linalg/TestLinalgElementwiseFusion.cpp b/mlir/test/lib/Dialect/Linalg/TestLinalgElementwiseFusion.cpp
index cb215197253bb..f8f2184b0de6f 100644
--- a/mlir/test/lib/Dialect/Linalg/TestLinalgElementwiseFusion.cpp
+++ b/mlir/test/lib/Dialect/Linalg/TestLinalgElementwiseFusion.cpp
@@ -252,7 +252,8 @@ struct TestLinalgElementwiseFusion
     if (fuseMultiUseProducer) {
       RewritePatternSet patterns(context);
       patterns.insert<TestMultiUseProducerFusion>(context);
-      if (failed(applyPatternsGreedily(funcOp.getBody(), std::move(patterns))))
+      if (failed(applyPatternsGreedily(funcOp.getBody(), std::move(patterns),
+                GreedyRewriteConfig().setUseTopDownTraversal(true))))
         return signalPassFailure();
       return;
     }

>From 0269bdf1ec7efb18bba6849af69297069eef3e59 Mon Sep 17 00:00:00 2001
From: Milos Poletanovic <mpoletanovic at syrmia.com>
Date: Wed, 17 Dec 2025 13:37:43 +0100
Subject: [PATCH 2/3] Format code.

---
 mlir/test/lib/Dialect/Linalg/TestLinalgElementwiseFusion.cpp | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/mlir/test/lib/Dialect/Linalg/TestLinalgElementwiseFusion.cpp b/mlir/test/lib/Dialect/Linalg/TestLinalgElementwiseFusion.cpp
index f8f2184b0de6f..81f4c881bacc8 100644
--- a/mlir/test/lib/Dialect/Linalg/TestLinalgElementwiseFusion.cpp
+++ b/mlir/test/lib/Dialect/Linalg/TestLinalgElementwiseFusion.cpp
@@ -252,8 +252,9 @@ struct TestLinalgElementwiseFusion
     if (fuseMultiUseProducer) {
       RewritePatternSet patterns(context);
       patterns.insert<TestMultiUseProducerFusion>(context);
-      if (failed(applyPatternsGreedily(funcOp.getBody(), std::move(patterns),
-                GreedyRewriteConfig().setUseTopDownTraversal(true))))
+      if (failed(applyPatternsGreedily(
+              funcOp.getBody(), std::move(patterns),
+              GreedyRewriteConfig().setUseTopDownTraversal(true))))
         return signalPassFailure();
       return;
     }

>From 700dd5c61a9843cb30500122c5e2b5045eb435b1 Mon Sep 17 00:00:00 2001
From: Milos Poletanovic <milos.poletanovic at htecgroup.com>
Date: Fri, 19 Dec 2025 00:47:13 +0100
Subject: [PATCH 3/3] Addressed comments.

---
 mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h | 3 +++
 mlir/test/Dialect/Linalg/fusion-multiuse-producer.mlir   | 2 +-
 2 files changed, 4 insertions(+), 1 deletion(-)

diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index d00183a1e16a1..3dc68086e442b 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -561,6 +561,9 @@ struct ElementwiseOpFusionResult {
   Operation *fusedOp;
   llvm::DenseMap<Value, Value> replacements;
 };
+/// This transformation is intended to be used with a top-down traversal
+/// (from producer to consumer). That way, the fusion logic can safely handle
+/// producers with multiple users.
 FailureOr<ElementwiseOpFusionResult>
 fuseElementwiseOps(RewriterBase &rewriter, OpOperand *fusedOperand);
 
diff --git a/mlir/test/Dialect/Linalg/fusion-multiuse-producer.mlir b/mlir/test/Dialect/Linalg/fusion-multiuse-producer.mlir
index 96845448dd1c2..ca788839a4927 100644
--- a/mlir/test/Dialect/Linalg/fusion-multiuse-producer.mlir
+++ b/mlir/test/Dialect/Linalg/fusion-multiuse-producer.mlir
@@ -102,4 +102,4 @@ func.func @multi_use_producer_2(%arg0: tensor<1x32x32x8xf32>) -> tensor<1x32x32x
 // CHECK: %[[C_SUB:.+]] = llvm.sub %[[C31]], %[[C_9]] : i64
 // CHECK: linalg.yield %[[C0:.+]], %[[C_SUB]] : f32, i64
 // CHECK: } -> (tensor<1x32x32x8xf32>, tensor<1x32x32x8xi64>)
-// CHECK: return %[[R0]]#0 : tensor<1x32x32x8xf32>
\ No newline at end of file
+// CHECK: return %[[R0]]#0 : tensor<1x32x32x8xf32>



More information about the Mlir-commits mailing list