[Mlir-commits] [mlir] [MLIR][Linalg] Use Top-Down traversal to safely optimize multi-use producer fusion (PR #172216)
Miloš Poletanović
llvmlistbot at llvm.org
Thu Dec 18 15:49:01 PST 2025
https://github.com/milos1397 updated https://github.com/llvm/llvm-project/pull/172216
>From e605dabf6d51a65cad92224b04a64420df4c0211 Mon Sep 17 00:00:00 2001
From: Milos Poletanovic <mpoletanovic at syrmia.com>
Date: Sun, 14 Dec 2025 16:15:20 +0100
Subject: [PATCH 1/3] [MLIR][Linalg] Use Top-Down traversal to safely optimize
multi-use producer fusion
Switches the greedy rewrite traversal used for the multi-use producer fusion
pattern to top-down (pre-order), so that producers are visited before their
consumers.
---
.../Linalg/fusion-multiuse-producer.mlir | 71 +++++++++++++++++++
.../Linalg/TestLinalgElementwiseFusion.cpp | 3 +-
2 files changed, 73 insertions(+), 1 deletion(-)
diff --git a/mlir/test/Dialect/Linalg/fusion-multiuse-producer.mlir b/mlir/test/Dialect/Linalg/fusion-multiuse-producer.mlir
index 7871ae08fd54a..96845448dd1c2 100644
--- a/mlir/test/Dialect/Linalg/fusion-multiuse-producer.mlir
+++ b/mlir/test/Dialect/Linalg/fusion-multiuse-producer.mlir
@@ -32,3 +32,74 @@ func.func @multi_use_producer(%arg0 : tensor<?x?xf32>, %arg1 : tensor<?x?xf32>,
// CHECK-SAME: %[[ARG4:[a-zA-Z0-9]+]]: tensor<?x?xf32>)
// CHECK: %[[RESULT:.+]]:3 = linalg.generic
// CHECK: return %[[RESULT]]#0, %[[RESULT]]#1, %[[RESULT]]#2
+
+func.func @multi_use_producer_2(%arg0: tensor<1x32x32x8xf32>) -> tensor<1x32x32x8xf32> attributes {llvm.emit_c_interface} {
+ %0 = llvm.mlir.constant(0.000000e+00 : f32) : f32
+ %1 = llvm.mlir.constant(31 : index) : i64
+ %2 = tensor.empty() : tensor<1x32x32x8xf32>
+ %3 = tensor.empty() : tensor<1x32x32x8xindex>
+ %4:2 = linalg.generic {
+ indexing_maps = [
+ affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>,
+ affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>,
+ affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+ ],
+ iterator_types = ["parallel", "parallel", "parallel", "parallel"]
+ }
+ ins(%arg0 : tensor<1x32x32x8xf32>)
+ outs(%2, %3 : tensor<1x32x32x8xf32>, tensor<1x32x32x8xindex>) {
+ ^bb0(%in: f32, %out: f32, %out_0: index):
+ %9 = linalg.index 1 : index
+ linalg.yield %0, %9 : f32, index
+ } -> (tensor<1x32x32x8xf32>, tensor<1x32x32x8xindex>)
+
+ %5 = tensor.empty() : tensor<1x32x32x8xi64>
+ %6:2 = linalg.generic {
+ indexing_maps = [
+ affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>,
+ affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>,
+ affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>,
+ affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+ ],
+ iterator_types = ["parallel", "parallel", "parallel", "parallel"]
+ }
+ ins(%arg0, %4#1 : tensor<1x32x32x8xf32>, tensor<1x32x32x8xindex>)
+ outs(%2, %5 : tensor<1x32x32x8xf32>, tensor<1x32x32x8xi64>) {
+ ^bb0(%in: f32, %in_0: index, %out: f32, %out_1: i64):
+ %9 = builtin.unrealized_conversion_cast %in_0 : index to i64
+ linalg.yield %0, %9 : f32, i64
+ } -> (tensor<1x32x32x8xf32>, tensor<1x32x32x8xi64>)
+
+ %7 = tensor.empty() : tensor<1x32x32x8xi64>
+ %8:2 = linalg.generic {
+ indexing_maps = [
+ affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>,
+ affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>,
+ affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>,
+ affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>,
+ affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+ ],
+ iterator_types = ["parallel", "parallel", "parallel", "parallel"]
+ }
+ ins(%arg0, %4#1, %6#1 : tensor<1x32x32x8xf32>, tensor<1x32x32x8xindex>, tensor<1x32x32x8xi64>)
+ outs(%2, %7 : tensor<1x32x32x8xf32>, tensor<1x32x32x8xi64>) {
+ ^bb0(%in: f32, %in_0: index, %in_1: i64, %out: f32, %out_2: i64):
+ %9 = llvm.sub %1, %in_1 : i64
+ linalg.yield %0, %9 : f32, i64
+ } -> (tensor<1x32x32x8xf32>, tensor<1x32x32x8xi64>)
+ return %8#0 : tensor<1x32x32x8xf32>
+}
+// CHECK-LABEL: func @multi_use_producer_2(
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<1x32x32x8xf32>)
+// CHECK-SAME: -> tensor<1x32x32x8xf32>
+// CHECK: %[[C31:.+]] = llvm.mlir.constant(31 : index) : i64
+// CHECK: %[[R0:.+]]:2 = linalg.generic {
+// CHECK-SAME: ins(%[[ARG0]], %[[ARG0]], %[[ARG0]], %[[ARG0]], %[[ARG0]] : tensor<1x32x32x8xf32>, tensor<1x32x32x8xf32>, tensor<1x32x32x8xf32>, tensor<1x32x32x8xf32>, tensor<1x32x32x8xf32>)
+// CHECK-SAME: outs(%[[INIT:.+]], %[[INIT_1:.+]] : tensor<1x32x32x8xf32>, tensor<1x32x32x8xi64>)
+// CHECK: ^bb0(%[[IN:.+]]: f32, %[[IN_1:.+]]: f32, %[[IN_2:.+]]: f32, %[[IN_3:.+]]: f32, %[[IN_4:.+]]: f32, %[[OUT:.+]]: f32, %[[OUT_I:.+]]: i64):
+// CHECK: %[[IDX_9:.+]] = linalg.index 1 : index
+// CHECK: %[[C_9:.+]] = builtin.unrealized_conversion_cast %[[IDX_9]] : index to i64
+// CHECK: %[[C_SUB:.+]] = llvm.sub %[[C31]], %[[C_9]] : i64
+// CHECK: linalg.yield %[[C0:.+]], %[[C_SUB]] : f32, i64
+// CHECK: } -> (tensor<1x32x32x8xf32>, tensor<1x32x32x8xi64>)
+// CHECK: return %[[R0]]#0 : tensor<1x32x32x8xf32>
\ No newline at end of file
diff --git a/mlir/test/lib/Dialect/Linalg/TestLinalgElementwiseFusion.cpp b/mlir/test/lib/Dialect/Linalg/TestLinalgElementwiseFusion.cpp
index cb215197253bb..f8f2184b0de6f 100644
--- a/mlir/test/lib/Dialect/Linalg/TestLinalgElementwiseFusion.cpp
+++ b/mlir/test/lib/Dialect/Linalg/TestLinalgElementwiseFusion.cpp
@@ -252,7 +252,8 @@ struct TestLinalgElementwiseFusion
if (fuseMultiUseProducer) {
RewritePatternSet patterns(context);
patterns.insert<TestMultiUseProducerFusion>(context);
- if (failed(applyPatternsGreedily(funcOp.getBody(), std::move(patterns))))
+ if (failed(applyPatternsGreedily(funcOp.getBody(), std::move(patterns),
+ GreedyRewriteConfig().setUseTopDownTraversal(true))))
return signalPassFailure();
return;
}
>From 0269bdf1ec7efb18bba6849af69297069eef3e59 Mon Sep 17 00:00:00 2001
From: Milos Poletanovic <mpoletanovic at syrmia.com>
Date: Wed, 17 Dec 2025 13:37:43 +0100
Subject: [PATCH 2/3] Format code.
---
mlir/test/lib/Dialect/Linalg/TestLinalgElementwiseFusion.cpp | 5 +++--
1 file changed, 3 insertions(+), 2 deletions(-)
diff --git a/mlir/test/lib/Dialect/Linalg/TestLinalgElementwiseFusion.cpp b/mlir/test/lib/Dialect/Linalg/TestLinalgElementwiseFusion.cpp
index f8f2184b0de6f..81f4c881bacc8 100644
--- a/mlir/test/lib/Dialect/Linalg/TestLinalgElementwiseFusion.cpp
+++ b/mlir/test/lib/Dialect/Linalg/TestLinalgElementwiseFusion.cpp
@@ -252,8 +252,9 @@ struct TestLinalgElementwiseFusion
if (fuseMultiUseProducer) {
RewritePatternSet patterns(context);
patterns.insert<TestMultiUseProducerFusion>(context);
- if (failed(applyPatternsGreedily(funcOp.getBody(), std::move(patterns),
- GreedyRewriteConfig().setUseTopDownTraversal(true))))
+ if (failed(applyPatternsGreedily(
+ funcOp.getBody(), std::move(patterns),
+ GreedyRewriteConfig().setUseTopDownTraversal(true))))
return signalPassFailure();
return;
}
>From 700dd5c61a9843cb30500122c5e2b5045eb435b1 Mon Sep 17 00:00:00 2001
From: Milos Poletanovic <milos.poletanovic at htecgroup.com>
Date: Fri, 19 Dec 2025 00:47:13 +0100
Subject: [PATCH 3/3] Addressed comments.
---
mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h | 3 +++
mlir/test/Dialect/Linalg/fusion-multiuse-producer.mlir | 2 +-
2 files changed, 4 insertions(+), 1 deletion(-)
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index d00183a1e16a1..3dc68086e442b 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -561,6 +561,9 @@ struct ElementwiseOpFusionResult {
Operation *fusedOp;
llvm::DenseMap<Value, Value> replacements;
};
+/// This transformation is intended to be used with a top-down traversal
+/// (from producer to consumer). That way, the fusion logic can safely handle
+/// producers with multiple users.
FailureOr<ElementwiseOpFusionResult>
fuseElementwiseOps(RewriterBase &rewriter, OpOperand *fusedOperand);
diff --git a/mlir/test/Dialect/Linalg/fusion-multiuse-producer.mlir b/mlir/test/Dialect/Linalg/fusion-multiuse-producer.mlir
index 96845448dd1c2..ca788839a4927 100644
--- a/mlir/test/Dialect/Linalg/fusion-multiuse-producer.mlir
+++ b/mlir/test/Dialect/Linalg/fusion-multiuse-producer.mlir
@@ -102,4 +102,4 @@ func.func @multi_use_producer_2(%arg0: tensor<1x32x32x8xf32>) -> tensor<1x32x32x
// CHECK: %[[C_SUB:.+]] = llvm.sub %[[C31]], %[[C_9]] : i64
// CHECK: linalg.yield %[[C0:.+]], %[[C_SUB]] : f32, i64
// CHECK: } -> (tensor<1x32x32x8xf32>, tensor<1x32x32x8xi64>)
-// CHECK: return %[[R0]]#0 : tensor<1x32x32x8xf32>
\ No newline at end of file
+// CHECK: return %[[R0]]#0 : tensor<1x32x32x8xf32>
More information about the Mlir-commits mailing list