[Mlir-commits] [mlir] 869b44d - [MLIR][Linalg] Use Top-Down traversal to safely optimize multi-use producer fusion (#172216)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Dec 23 08:28:34 PST 2025


Author: Miloš Poletanović
Date: 2025-12-23T08:28:30-08:00
New Revision: 869b44d20cf4fd64c7859e94a274ad61863de95b

URL: https://github.com/llvm/llvm-project/commit/869b44d20cf4fd64c7859e94a274ad61863de95b
DIFF: https://github.com/llvm/llvm-project/commit/869b44d20cf4fd64c7859e94a274ad61863de95b.diff

LOG: [MLIR][Linalg] Use Top-Down traversal to safely optimize multi-use producer fusion (#172216)

Switches the greedy rewrite traversal for the multi-use producer fusion
pattern to Top-Down (Pre-Order).

The previous Bottom-Up (Post-Order) traversal led to a critical SSA
dominance violation when a producer (P) had multiple users (I and C) and
the first user (I) appeared before the current consumer (C) in the block.
Visiting the later consumer (C) first and fusing P into C would create a
new fused operation, F, at C's position. The rewrite would then replace
P's result (still used by I) with a result of F. However, since I is
located before F in the block, this replacement breaks SSA dominance
rules, leading to a crash. To ensure correctness, the first use (I) must
be processed and fused before the second use (C). Top-Down traversal
visits operations in program order (producers before their consumers),
so the rewrites are applied in the correct order.

Consider the following example: a three-operation chain in which the
first operation, P (**%13:2**), has two users: an intermediate
operation I (**%15:2**) and a final consumer C (**%17:2**):
```
#map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
module {
  func.func @avgpool2d_pad_top(%arg0: tensor<1x32x32x8xf32>) -> tensor<1x32x32x8xf32> attributes {llvm.emit_c_interface} {
    %0 = llvm.mlir.constant(0.000000e+00 : f32) : f32
    %1 = llvm.mlir.constant(31 : index) : i64
    %11 = tensor.empty() : tensor<1x32x32x8xf32>
    %12 = tensor.empty() : tensor<1x32x32x8xindex>
    %13:2 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg0 : tensor<1x32x32x8xf32>) outs(%11, %12 : tensor<1x32x32x8xf32>, tensor<1x32x32x8xindex>) {
    ^bb0(%in: f32, %out: f32, %out_0: index):
      %59 = linalg.index 1 : index
      linalg.yield %0, %59 : f32, index
    } -> (tensor<1x32x32x8xf32>, tensor<1x32x32x8xindex>)
    %14 = tensor.empty() : tensor<1x32x32x8xi64>
    %15:2 = linalg.generic {indexing_maps = [#map, #map, #map, #map], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg0, %13#1 : tensor<1x32x32x8xf32>, tensor<1x32x32x8xindex>) outs(%11, %14 : tensor<1x32x32x8xf32>, tensor<1x32x32x8xi64>) {
    ^bb0(%in: f32, %in_0: index, %out: f32, %out_1: i64):
      %59 = builtin.unrealized_conversion_cast %in_0 : index to i64
      linalg.yield %0, %59 : f32, i64
    } -> (tensor<1x32x32x8xf32>, tensor<1x32x32x8xi64>)
    %16 = tensor.empty() : tensor<1x32x32x8xi64>
    %17:2 = linalg.generic {indexing_maps = [#map, #map, #map, #map, #map], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg0, %13#1, %15#1 : tensor<1x32x32x8xf32>, tensor<1x32x32x8xindex>, tensor<1x32x32x8xi64>) outs(%11, %16 : tensor<1x32x32x8xf32>, tensor<1x32x32x8xi64>) {
    ^bb0(%in: f32, %in_0: index, %in_1: i64, %out: f32, %out_2: i64):
      %59 = llvm.sub %1, %in_1 : i64
      linalg.yield %0, %59 : f32, i64
    } -> (tensor<1x32x32x8xf32>, tensor<1x32x32x8xi64>)
    return %17 : tensor<1x32x32x8xf32>
  }
}
```
If the fused op is inserted at the position of **%17**, the rewrite
mechanism must update all users of P's result (**%13**). The
intermediate user I (**%15**) precedes the final consumer C (**%17**)
in the block. Replacing I's operand (**%13#1**) with a result of the
new fused operation, which is defined after I, therefore violates SSA
dominance and causes the compiler to crash.

Issue: [#131446](https://github.com/llvm/llvm-project/issues/131446)

---------

Co-authored-by: Milos Poletanovic <mpoletanovic at syrmia.com>
Co-authored-by: Milos Poletanovic <milos.poletanovic at htecgroup.com>

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
    mlir/test/Dialect/Linalg/fusion-multiuse-producer.mlir
    mlir/test/lib/Dialect/Linalg/TestLinalgElementwiseFusion.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 7258af4833ae5..6678d693719bf 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -657,6 +657,9 @@ struct ElementwiseOpFusionResult {
   Operation *fusedOp;
   llvm::DenseMap<Value, Value> replacements;
 };
+/// This transformation is intended to be used with a top-down traversal
+/// (from producer to consumer). In that way fusion logic can safely handle
+/// producers with multiple users.
 FailureOr<ElementwiseOpFusionResult>
 fuseElementwiseOps(RewriterBase &rewriter, OpOperand *fusedOperand);
 

diff  --git a/mlir/test/Dialect/Linalg/fusion-multiuse-producer.mlir b/mlir/test/Dialect/Linalg/fusion-multiuse-producer.mlir
index 7871ae08fd54a..ca788839a4927 100644
--- a/mlir/test/Dialect/Linalg/fusion-multiuse-producer.mlir
+++ b/mlir/test/Dialect/Linalg/fusion-multiuse-producer.mlir
@@ -32,3 +32,74 @@ func.func @multi_use_producer(%arg0 : tensor<?x?xf32>, %arg1 : tensor<?x?xf32>,
 // CHECK-SAME:     %[[ARG4:[a-zA-Z0-9]+]]: tensor<?x?xf32>)
 //      CHECK:   %[[RESULT:.+]]:3 = linalg.generic
 //      CHECK:   return %[[RESULT]]#0, %[[RESULT]]#1, %[[RESULT]]#2
+
+func.func @multi_use_producer_2(%arg0: tensor<1x32x32x8xf32>) -> tensor<1x32x32x8xf32> attributes {llvm.emit_c_interface} {
+  %0 = llvm.mlir.constant(0.000000e+00 : f32) : f32
+  %1 = llvm.mlir.constant(31 : index) : i64
+  %2 = tensor.empty() : tensor<1x32x32x8xf32>
+  %3 = tensor.empty() : tensor<1x32x32x8xindex>
+  %4:2 = linalg.generic {
+    indexing_maps = [
+      affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>,
+      affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>,
+      affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+    ], 
+    iterator_types = ["parallel", "parallel", "parallel", "parallel"]
+  } 
+  ins(%arg0 : tensor<1x32x32x8xf32>) 
+  outs(%2, %3 : tensor<1x32x32x8xf32>, tensor<1x32x32x8xindex>) {
+    ^bb0(%in: f32, %out: f32, %out_0: index):
+      %9 = linalg.index 1 : index
+      linalg.yield %0, %9 : f32, index
+  } -> (tensor<1x32x32x8xf32>, tensor<1x32x32x8xindex>)
+
+  %5 = tensor.empty() : tensor<1x32x32x8xi64>
+  %6:2 = linalg.generic {
+    indexing_maps = [
+      affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>,
+      affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>,
+      affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>,
+      affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+    ], 
+    iterator_types = ["parallel", "parallel", "parallel", "parallel"]
+  } 
+  ins(%arg0, %4#1 : tensor<1x32x32x8xf32>, tensor<1x32x32x8xindex>) 
+  outs(%2, %5 : tensor<1x32x32x8xf32>, tensor<1x32x32x8xi64>) {
+    ^bb0(%in: f32, %in_0: index, %out: f32, %out_1: i64):
+      %9 = builtin.unrealized_conversion_cast %in_0 : index to i64
+      linalg.yield %0, %9 : f32, i64
+  } -> (tensor<1x32x32x8xf32>, tensor<1x32x32x8xi64>)
+
+  %7 = tensor.empty() : tensor<1x32x32x8xi64>
+  %8:2 = linalg.generic {
+    indexing_maps = [
+      affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>,
+      affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>,
+      affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>,
+      affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>,
+      affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+    ], 
+    iterator_types = ["parallel", "parallel", "parallel", "parallel"]
+  } 
+  ins(%arg0, %4#1, %6#1 : tensor<1x32x32x8xf32>, tensor<1x32x32x8xindex>, tensor<1x32x32x8xi64>) 
+  outs(%2, %7 : tensor<1x32x32x8xf32>, tensor<1x32x32x8xi64>) {
+    ^bb0(%in: f32, %in_0: index, %in_1: i64, %out: f32, %out_2: i64):
+      %9 = llvm.sub %1, %in_1 : i64
+      linalg.yield %0, %9 : f32, i64
+  } -> (tensor<1x32x32x8xf32>, tensor<1x32x32x8xi64>)
+  return %8#0 : tensor<1x32x32x8xf32>
+}
+// CHECK-LABEL: func @multi_use_producer_2(
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<1x32x32x8xf32>)
+// CHECK-SAME: -> tensor<1x32x32x8xf32>
+// CHECK: %[[C31:.+]] = llvm.mlir.constant(31 : index) : i64
+// CHECK: %[[R0:.+]]:2 = linalg.generic {
+// CHECK-SAME: ins(%[[ARG0]], %[[ARG0]], %[[ARG0]], %[[ARG0]], %[[ARG0]] : tensor<1x32x32x8xf32>, tensor<1x32x32x8xf32>, tensor<1x32x32x8xf32>, tensor<1x32x32x8xf32>, tensor<1x32x32x8xf32>)
+// CHECK-SAME: outs(%[[INIT:.+]], %[[INIT_1:.+]] : tensor<1x32x32x8xf32>, tensor<1x32x32x8xi64>)
+// CHECK: ^bb0(%[[IN:.+]]: f32, %[[IN_1:.+]]: f32, %[[IN_2:.+]]: f32, %[[IN_3:.+]]: f32, %[[IN_4:.+]]: f32, %[[OUT:.+]]: f32, %[[OUT_I:.+]]: i64):
+// CHECK: %[[IDX_9:.+]] = linalg.index 1 : index
+// CHECK: %[[C_9:.+]] = builtin.unrealized_conversion_cast %[[IDX_9]] : index to i64
+// CHECK: %[[C_SUB:.+]] = llvm.sub %[[C31]], %[[C_9]] : i64
+// CHECK: linalg.yield %[[C0:.+]], %[[C_SUB]] : f32, i64
+// CHECK: } -> (tensor<1x32x32x8xf32>, tensor<1x32x32x8xi64>)
+// CHECK: return %[[R0]]#0 : tensor<1x32x32x8xf32>

diff  --git a/mlir/test/lib/Dialect/Linalg/TestLinalgElementwiseFusion.cpp b/mlir/test/lib/Dialect/Linalg/TestLinalgElementwiseFusion.cpp
index cb215197253bb..81f4c881bacc8 100644
--- a/mlir/test/lib/Dialect/Linalg/TestLinalgElementwiseFusion.cpp
+++ b/mlir/test/lib/Dialect/Linalg/TestLinalgElementwiseFusion.cpp
@@ -252,7 +252,9 @@ struct TestLinalgElementwiseFusion
     if (fuseMultiUseProducer) {
       RewritePatternSet patterns(context);
       patterns.insert<TestMultiUseProducerFusion>(context);
-      if (failed(applyPatternsGreedily(funcOp.getBody(), std::move(patterns))))
+      if (failed(applyPatternsGreedily(
+              funcOp.getBody(), std::move(patterns),
+              GreedyRewriteConfig().setUseTopDownTraversal(true))))
         return signalPassFailure();
       return;
     }


        


More information about the Mlir-commits mailing list