[Mlir-commits] [mlir] [MLIR][linalg] Erase effectful producer after fusion if it's no longer used (PR #170036)

Artemiy Bulavin llvmlistbot at llvm.org
Sun Nov 30 07:01:30 PST 2025


https://github.com/abulavin updated https://github.com/llvm/llvm-project/pull/170036

From af8cd231fa067faa3b6ad309b37b72ecaf5e8623 Mon Sep 17 00:00:00 2001
From: Artemiy Bulavin <artemiyb at graphcore.ai>
Date: Sun, 30 Nov 2025 14:30:26 +0000
Subject: [PATCH] Remove effectful producer after fusion if no longer used

---
 .../Linalg/Transforms/ElementwiseOpFusion.cpp |  6 ++++
 .../Linalg/fusion-elementwise-ops.mlir        | 34 +++++++++++++++++++
 2 files changed, 40 insertions(+)

diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
index 05fc7cbbb90af..b65e5dd35ff3c 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
@@ -493,6 +493,12 @@ class FuseElementwiseOps : public OpRewritePattern<GenericOp> {
         });
       }
       rewriter.eraseOp(genericOp);
+      // If the producer has no remaining uses after fusion, erase it. The
+      // greedy pattern driver usually does this, but it won't consider the
+      // producer trivially dead if its body contains ops with memory effects.
+      if (producer->use_empty())
+        rewriter.eraseOp(producer);
+
       return success();
     }
     return failure();
diff --git a/mlir/test/Dialect/Linalg/fusion-elementwise-ops.mlir b/mlir/test/Dialect/Linalg/fusion-elementwise-ops.mlir
index 6f1a422324e08..b47aeb2812210 100644
--- a/mlir/test/Dialect/Linalg/fusion-elementwise-ops.mlir
+++ b/mlir/test/Dialect/Linalg/fusion-elementwise-ops.mlir
@@ -1056,3 +1056,37 @@ module {
 // CHECK:         tensor.expand_shape
 // CHECK:         linalg.generic {{.*}}, iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "parallel", "reduction"]}
 // CHECK-SAME:     ins(%[[ARG0]], %[[FUSED]]#1 : tensor<1x1x2x1xf32>, tensor<4x1x1x1xf32>)
+
+// -----
+
+#map = affine_map<(d0) -> (d0)>
+// The atomic_rmw in the producer's body keeps it from being trivially dead.
+module {
+  func.func @remove_effectful_producer_after_fusion_if_no_uses(%arg0: bf16, %arg1: memref<2xbf16>, %arg2: memref<2xbf16>) {
+    %c0 = arith.constant 0 : index
+    %0 = tensor.empty() : tensor<2xbf16>
+    %1 = linalg.generic {indexing_maps = [#map], iterator_types = ["parallel"]} outs(%0 : tensor<2xbf16>) {
+    ^bb0(%out: bf16):
+      %2 = memref.atomic_rmw addf %arg0, %arg1[%c0] : (bf16, memref<2xbf16>) -> bf16
+      linalg.yield %2 : bf16
+    } -> tensor<2xbf16>
+    // After fusion the producer is unused; the pattern must erase it itself.
+    linalg.generic {indexing_maps = [#map], iterator_types = ["parallel"]} ins(%1 : tensor<2xbf16>) {
+    ^bb0(%in: bf16):
+      memref.store %in, %arg2[%c0] : memref<2xbf16>
+      linalg.yield
+    }
+    return
+  }
+}
+
+// CHECK-LABEL:   func.func @remove_effectful_producer_after_fusion_if_no_uses
+// CHECK-SAME:      %[[ARG0:.*]]: bf16,
+// CHECK-SAME:      %[[ARG1:.*]]: memref<2xbf16>,
+// CHECK-SAME:      %[[ARG2:.*]]: memref<2xbf16>)
+// CHECK:           linalg.generic
+// CHECK:             %[[ATOMIC_RMW:.*]] = memref.atomic_rmw addf %[[ARG0]], %[[ARG1]]
+// CHECK-NOT:         linalg.generic
+// CHECK:             memref.store %[[ATOMIC_RMW]], %[[ARG2]]
+// CHECK:             linalg.yield %[[ATOMIC_RMW]]
+// CHECK:           return



More information about the Mlir-commits mailing list