[Mlir-commits] [mlir] [MLIR][Linalg] Use Top-Down traversal to safely optimize multi-use producer fusion (PR #172216)

Miloš Poletanović llvmlistbot at llvm.org
Sun Dec 14 07:40:26 PST 2025


https://github.com/milos1397 created https://github.com/llvm/llvm-project/pull/172216

Switches the greedy rewrite traversal for the multi-use producer fusion pattern to Top-Down (Pre-Order).

The previous Bottom-Up (Post-Order) traversal led to an SSA dominance violation when a producer (P) had multiple users (I and C) and the first user (I) appeared before the current consumer (C) in the block. With Bottom-Up traversal, C is visited first, so the pattern fuses P into C and creates a new fused operation, F, at C's position. The rewrite then replaces the remaining uses of P's result, including the one in I, with the output of F. Since I is located before F in the block, this replacement breaks SSA dominance and leads to a crash.

To ensure correctness, the first use (I) must be processed and fused before the second use (C). Top-Down traversal visits operations in block order, so I is fused with P before C is reached.

Consider the following example, a chain of three operations in which the first operation, P (**%13:2**), has two users, an intermediate operation I (**%15:2**) and a final consumer C (**%17:2**), both of which consume P's second result (**%13#1**):
```
#map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
module {
  func.func @avgpool2d_pad_top(%arg0: tensor<1x32x32x8xf32>) -> tensor<1x32x32x8xf32> attributes {llvm.emit_c_interface} {
    %0 = llvm.mlir.constant(0.000000e+00 : f32) : f32
    %1 = llvm.mlir.constant(31 : index) : i64
    %11 = tensor.empty() : tensor<1x32x32x8xf32>
    %12 = tensor.empty() : tensor<1x32x32x8xindex>
    %13:2 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg0 : tensor<1x32x32x8xf32>) outs(%11, %12 : tensor<1x32x32x8xf32>, tensor<1x32x32x8xindex>) {
    ^bb0(%in: f32, %out: f32, %out_0: index):
      %59 = linalg.index 1 : index
      linalg.yield %0, %59 : f32, index
    } -> (tensor<1x32x32x8xf32>, tensor<1x32x32x8xindex>)
    %14 = tensor.empty() : tensor<1x32x32x8xi64>
    %15:2 = linalg.generic {indexing_maps = [#map, #map, #map, #map], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg0, %13#1 : tensor<1x32x32x8xf32>, tensor<1x32x32x8xindex>) outs(%11, %14 : tensor<1x32x32x8xf32>, tensor<1x32x32x8xi64>) {
    ^bb0(%in: f32, %in_0: index, %out: f32, %out_1: i64):
      %59 = builtin.unrealized_conversion_cast %in_0 : index to i64
      linalg.yield %0, %59 : f32, i64
    } -> (tensor<1x32x32x8xf32>, tensor<1x32x32x8xi64>)
    %16 = tensor.empty() : tensor<1x32x32x8xi64>
    %17:2 = linalg.generic {indexing_maps = [#map, #map, #map, #map, #map], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg0, %13#1, %15#1 : tensor<1x32x32x8xf32>, tensor<1x32x32x8xindex>, tensor<1x32x32x8xi64>) outs(%11, %16 : tensor<1x32x32x8xf32>, tensor<1x32x32x8xi64>) {
    ^bb0(%in: f32, %in_0: index, %in_1: i64, %out: f32, %out_2: i64):
      %59 = llvm.sub %1, %in_1 : i64
      linalg.yield %0, %59 : f32, i64
    } -> (tensor<1x32x32x8xf32>, tensor<1x32x32x8xi64>)
    return %17#0 : tensor<1x32x32x8xf32>
  }
}
```
If the fused operation is inserted at the position of **%17**, the rewrite must redirect every remaining use of P's result (**%13#1**) to the output of the fused operation. Because the intermediate user I (**%15**) precedes the final consumer C (**%17**) in the block, replacing I's operand (**%13#1**) with the new fused result makes I use a value that is defined after it. This violates SSA dominance and crashes the compiler.

Issue: [#131446](https://github.com/llvm/llvm-project/issues/131446)

From e605dabf6d51a65cad92224b04a64420df4c0211 Mon Sep 17 00:00:00 2001
From: Milos Poletanovic <mpoletanovic at syrmia.com>
Date: Sun, 14 Dec 2025 16:15:20 +0100
Subject: [PATCH] [MLIR][Linalg] Use Top-Down traversal to safely optimize
 multi-use producer fusion

Switches the greedy rewrite traversal for the multi-use producer fusion
pattern to Top-Down (Pre-Order).
---
 .../Linalg/fusion-multiuse-producer.mlir      | 71 +++++++++++++++++++
 .../Linalg/TestLinalgElementwiseFusion.cpp    |  3 +-
 2 files changed, 73 insertions(+), 1 deletion(-)

diff --git a/mlir/test/Dialect/Linalg/fusion-multiuse-producer.mlir b/mlir/test/Dialect/Linalg/fusion-multiuse-producer.mlir
index 7871ae08fd54a..96845448dd1c2 100644
--- a/mlir/test/Dialect/Linalg/fusion-multiuse-producer.mlir
+++ b/mlir/test/Dialect/Linalg/fusion-multiuse-producer.mlir
@@ -32,3 +32,74 @@ func.func @multi_use_producer(%arg0 : tensor<?x?xf32>, %arg1 : tensor<?x?xf32>,
 // CHECK-SAME:     %[[ARG4:[a-zA-Z0-9]+]]: tensor<?x?xf32>)
 //      CHECK:   %[[RESULT:.+]]:3 = linalg.generic
 //      CHECK:   return %[[RESULT]]#0, %[[RESULT]]#1, %[[RESULT]]#2
+
+func.func @multi_use_producer_2(%arg0: tensor<1x32x32x8xf32>) -> tensor<1x32x32x8xf32> attributes {llvm.emit_c_interface} {
+  %0 = llvm.mlir.constant(0.000000e+00 : f32) : f32
+  %1 = llvm.mlir.constant(31 : index) : i64
+  %2 = tensor.empty() : tensor<1x32x32x8xf32>
+  %3 = tensor.empty() : tensor<1x32x32x8xindex>
+  %4:2 = linalg.generic {
+    indexing_maps = [
+      affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>,
+      affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>,
+      affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+    ],
+    iterator_types = ["parallel", "parallel", "parallel", "parallel"]
+  }
+  ins(%arg0 : tensor<1x32x32x8xf32>)
+  outs(%2, %3 : tensor<1x32x32x8xf32>, tensor<1x32x32x8xindex>) {
+    ^bb0(%in: f32, %out: f32, %out_0: index):
+      %9 = linalg.index 1 : index
+      linalg.yield %0, %9 : f32, index
+  } -> (tensor<1x32x32x8xf32>, tensor<1x32x32x8xindex>)
+
+  %5 = tensor.empty() : tensor<1x32x32x8xi64>
+  %6:2 = linalg.generic {
+    indexing_maps = [
+      affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>,
+      affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>,
+      affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>,
+      affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+    ],
+    iterator_types = ["parallel", "parallel", "parallel", "parallel"]
+  }
+  ins(%arg0, %4#1 : tensor<1x32x32x8xf32>, tensor<1x32x32x8xindex>)
+  outs(%2, %5 : tensor<1x32x32x8xf32>, tensor<1x32x32x8xi64>) {
+    ^bb0(%in: f32, %in_0: index, %out: f32, %out_1: i64):
+      %9 = builtin.unrealized_conversion_cast %in_0 : index to i64
+      linalg.yield %0, %9 : f32, i64
+  } -> (tensor<1x32x32x8xf32>, tensor<1x32x32x8xi64>)
+
+  %7 = tensor.empty() : tensor<1x32x32x8xi64>
+  %8:2 = linalg.generic {
+    indexing_maps = [
+      affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>,
+      affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>,
+      affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>,
+      affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>,
+      affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+    ],
+    iterator_types = ["parallel", "parallel", "parallel", "parallel"]
+  }
+  ins(%arg0, %4#1, %6#1 : tensor<1x32x32x8xf32>, tensor<1x32x32x8xindex>, tensor<1x32x32x8xi64>)
+  outs(%2, %7 : tensor<1x32x32x8xf32>, tensor<1x32x32x8xi64>) {
+    ^bb0(%in: f32, %in_0: index, %in_1: i64, %out: f32, %out_2: i64):
+      %9 = llvm.sub %1, %in_1 : i64
+      linalg.yield %0, %9 : f32, i64
+  } -> (tensor<1x32x32x8xf32>, tensor<1x32x32x8xi64>)
+  return %8#0 : tensor<1x32x32x8xf32>
+}
+// CHECK-LABEL: func @multi_use_producer_2(
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<1x32x32x8xf32>)
+// CHECK-SAME: -> tensor<1x32x32x8xf32>
+// CHECK: %[[C31:.+]] = llvm.mlir.constant(31 : index) : i64
+// CHECK: %[[R0:.+]]:2 = linalg.generic {
+// CHECK-SAME: ins(%[[ARG0]], %[[ARG0]], %[[ARG0]], %[[ARG0]], %[[ARG0]] : tensor<1x32x32x8xf32>, tensor<1x32x32x8xf32>, tensor<1x32x32x8xf32>, tensor<1x32x32x8xf32>, tensor<1x32x32x8xf32>)
+// CHECK-SAME: outs(%[[INIT:.+]], %[[INIT_1:.+]] : tensor<1x32x32x8xf32>, tensor<1x32x32x8xi64>)
+// CHECK: ^bb0(%[[IN:.+]]: f32, %[[IN_1:.+]]: f32, %[[IN_2:.+]]: f32, %[[IN_3:.+]]: f32, %[[IN_4:.+]]: f32, %[[OUT:.+]]: f32, %[[OUT_I:.+]]: i64):
+// CHECK: %[[IDX_9:.+]] = linalg.index 1 : index
+// CHECK: %[[C_9:.+]] = builtin.unrealized_conversion_cast %[[IDX_9]] : index to i64
+// CHECK: %[[C_SUB:.+]] = llvm.sub %[[C31]], %[[C_9]] : i64
+// CHECK: linalg.yield %[[C0:.+]], %[[C_SUB]] : f32, i64
+// CHECK: } -> (tensor<1x32x32x8xf32>, tensor<1x32x32x8xi64>)
+// CHECK: return %[[R0]]#0 : tensor<1x32x32x8xf32>
\ No newline at end of file
diff --git a/mlir/test/lib/Dialect/Linalg/TestLinalgElementwiseFusion.cpp b/mlir/test/lib/Dialect/Linalg/TestLinalgElementwiseFusion.cpp
index cb215197253bb..f8f2184b0de6f 100644
--- a/mlir/test/lib/Dialect/Linalg/TestLinalgElementwiseFusion.cpp
+++ b/mlir/test/lib/Dialect/Linalg/TestLinalgElementwiseFusion.cpp
@@ -252,7 +252,8 @@ struct TestLinalgElementwiseFusion
     if (fuseMultiUseProducer) {
       RewritePatternSet patterns(context);
       patterns.insert<TestMultiUseProducerFusion>(context);
-      if (failed(applyPatternsGreedily(funcOp.getBody(), std::move(patterns))))
+      if (failed(applyPatternsGreedily(funcOp.getBody(), std::move(patterns),
+                GreedyRewriteConfig().setUseTopDownTraversal(true))))
         return signalPassFailure();
       return;
     }


