[Mlir-commits] [mlir] [mlir] Allow blacklist ops for reduction linalg elementwise fusion (PR #144176)

Evan Liu llvmlistbot at llvm.org
Fri Jun 13 18:12:02 PDT 2025


https://github.com/Evanyl created https://github.com/llvm/llvm-project/pull/144176

We should allow users to blacklist specific ops from linalg elementwise fusion into reduction consumers, as fusing those ops into a reduction may be expensive on their hardware.

From 7d8d2d8bfd338eb09d71cb2578ffd7166f8c3c8d Mon Sep 17 00:00:00 2001
From: Evan Liu <liuyievan at gmail.com>
Date: Fri, 13 Jun 2025 17:59:57 -0700
Subject: [PATCH] [mlir] Allow blacklist ops for reduction linalg elementwise
 fusion

---
 mlir/include/mlir/Dialect/Linalg/Passes.td    |  3 +
 .../Dialect/Linalg/Transforms/Transforms.h    |  9 ++-
 .../Linalg/Transforms/ElementwiseOpFusion.cpp | 62 +++++++++++++++----
 .../Linalg/fusion-elementwise-blacklist.mlir  | 49 +++++++++++++++
 .../Linalg/TestLinalgElementwiseFusion.cpp    | 25 +++++++-
 5 files changed, 132 insertions(+), 16 deletions(-)
 create mode 100644 mlir/test/Dialect/Linalg/fusion-elementwise-blacklist.mlir

diff --git a/mlir/include/mlir/Dialect/Linalg/Passes.td b/mlir/include/mlir/Dialect/Linalg/Passes.td
index 373842c9b03de..5db234770ef5c 100644
--- a/mlir/include/mlir/Dialect/Linalg/Passes.td
+++ b/mlir/include/mlir/Dialect/Linalg/Passes.td
@@ -72,6 +72,9 @@ def LinalgFoldUnitExtentDimsPass : Pass<"linalg-fold-unit-extent-dims", ""> {
 
 def LinalgElementwiseOpFusionPass : Pass<"linalg-fuse-elementwise-ops"> {
   let summary = "Fuse elementwise operations on tensors";
+  let options = [ListOption<"reductionFusionOpBlacklist",
+                            "reduction-fusion-blacklist", "std::string",
+                            "List of ops to blacklist for reduction fusion.">];
   let dependentDialects = [
     "affine::AffineDialect", "linalg::LinalgDialect", "memref::MemRefDialect"
   ];
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 2eef0a06d0eb4..e4968930ce554 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -450,7 +450,8 @@ using ControlSplitReductionFn =
 /// Return true if two `linalg.generic` operations with producer/consumer
 /// relationship through `fusedOperand` can be fused using elementwise op
 /// fusion.
-bool areElementwiseOpsFusable(OpOperand *fusedOperand);
+bool areElementwiseOpsFusable(OpOperand *fusedOperand,
+                              llvm::StringSet<> &blacklistedReductionFusionOps);
 
 /// Promote memref.subviews feeding linalg-on-buffers operations.
 LogicalResult promoteSubviewsPrecondition(Operation *op,
@@ -505,7 +506,8 @@ struct ElementwiseOpFusionResult {
   llvm::DenseMap<Value, Value> replacements;
 };
 FailureOr<ElementwiseOpFusionResult>
-fuseElementwiseOps(RewriterBase &rewriter, OpOperand *fusedOperand);
+fuseElementwiseOps(RewriterBase &rewriter, OpOperand *fusedOperand,
+                   llvm::StringSet<> &blacklistedReductionFusionOps);
 
 /// Returns a set of indices of the producer's results which would
 /// be preserved after the fusion.
@@ -1783,7 +1785,8 @@ using ControlFusionFn = std::function<bool(OpOperand *fusedOperand)>;
 /// when both operations are fusable elementwise operations.
 void populateElementwiseOpsFusionPatterns(
     RewritePatternSet &patterns,
-    const ControlFusionFn &controlElementwiseOpFusion);
+    const ControlFusionFn &controlElementwiseOpFusion,
+    llvm::StringSet<> *blacklistedReductionFusionOps = nullptr);
 
 /// Function type which is used to control propagation of linalg.pack/unpack
 /// ops.
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
index f97ed3d6d5111..e10bba04951b7 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
@@ -104,6 +104,20 @@ static bool isOpOperandCanBeDroppedAfterFusedLinalgs(
              indexingMaps, producer.getContext())) != AffineMap();
 }
 
+static bool
+shouldFuseIntoReduction(linalg::GenericOp op,
+                        llvm::StringSet<> &blacklistedReductionFusionOps) {
+  for (Operation &innerOp : op.getRegion().front()) {
+    if (innerOp.hasTrait<OpTrait::IsTerminator>())
+      continue;
+
+    if (blacklistedReductionFusionOps.contains(
+            innerOp.getName().getStringRef()))
+      return false;
+  }
+  return true;
+}
+
 /// Returns a set of indices of the producer's results which would
 /// be preserved after the fusion.
 /// * There is a chance that the implementation of the transformation does not
@@ -136,7 +150,8 @@ llvm::SmallDenseSet<int> mlir::linalg::getPreservedProducerResults(
 }
 
 /// Conditions for elementwise fusion of generic operations.
-bool mlir::linalg::areElementwiseOpsFusable(OpOperand *fusedOperand) {
+bool mlir::linalg::areElementwiseOpsFusable(
+    OpOperand *fusedOperand, llvm::StringSet<> &blacklistedReductionFusionOps) {
   if (!fusedOperand)
     return false;
 
@@ -159,6 +174,10 @@ bool mlir::linalg::areElementwiseOpsFusable(OpOperand *fusedOperand) {
   if (producer.getNumParallelLoops() != producer.getNumLoops())
     return false;
 
+  if (consumer.getNumReductionLoops() > 0 &&
+      !shouldFuseIntoReduction(producer, blacklistedReductionFusionOps))
+    return false;
+
   // Only allow fusing the producer of an input operand for now.
   // TODO: allow fusing the producer of an output operand.
   if (!consumer.isDpsInput(fusedOperand))
@@ -335,10 +354,12 @@ static void generateFusedElementwiseOpRegion(
 }
 
 FailureOr<mlir::linalg::ElementwiseOpFusionResult>
-mlir::linalg::fuseElementwiseOps(RewriterBase &rewriter,
-                                 OpOperand *fusedOperand) {
-  assert(areElementwiseOpsFusable(fusedOperand) &&
-         "expected elementwise operation pre-conditions to pass");
+mlir::linalg::fuseElementwiseOps(
+    RewriterBase &rewriter, OpOperand *fusedOperand,
+    llvm::StringSet<> &blacklistedReductionFusionOps) {
+  assert(
+      areElementwiseOpsFusable(fusedOperand, blacklistedReductionFusionOps) &&
+      "expected elementwise operation pre-conditions to pass");
   auto producerResult = cast<OpResult>(fusedOperand->get());
   auto producer = cast<GenericOp>(producerResult.getOwner());
   auto consumer = cast<GenericOp>(fusedOperand->getOwner());
@@ -462,16 +483,19 @@ namespace {
 /// Patterns to fuse a generic op, with the producer of its operands.
 class FuseElementwiseOps : public OpRewritePattern<GenericOp> {
 public:
+  llvm::StringSet<> &blacklistedReductionFusionOps;
   FuseElementwiseOps(MLIRContext *context, ControlFusionFn fun,
+                     llvm::StringSet<> &blacklistedReductionFusionOps,
                      PatternBenefit benefit = 1)
       : OpRewritePattern<GenericOp>(context, benefit),
+        blacklistedReductionFusionOps(blacklistedReductionFusionOps),
         controlFn(std::move(fun)) {}
 
   LogicalResult matchAndRewrite(GenericOp genericOp,
                                 PatternRewriter &rewriter) const override {
     // Find the first operand that is defined by another generic op on tensors.
     for (OpOperand &opOperand : genericOp->getOpOperands()) {
-      if (!areElementwiseOpsFusable(&opOperand))
+      if (!areElementwiseOpsFusable(&opOperand, blacklistedReductionFusionOps))
         continue;
       if (!controlFn(&opOperand))
         continue;
@@ -479,8 +503,8 @@ class FuseElementwiseOps : public OpRewritePattern<GenericOp> {
       Operation *producer = opOperand.get().getDefiningOp();
 
       // Find the producer of the operand.
-      FailureOr<ElementwiseOpFusionResult> fusionResult =
-          fuseElementwiseOps(rewriter, &opOperand);
+      FailureOr<ElementwiseOpFusionResult> fusionResult = fuseElementwiseOps(
+          rewriter, &opOperand, blacklistedReductionFusionOps);
       if (failed(fusionResult))
         return rewriter.notifyMatchFailure(genericOp, "fusion failed");
 
@@ -2248,9 +2272,17 @@ void mlir::linalg::populateFoldReshapeOpsByCollapsingPatterns(
 
 void mlir::linalg::populateElementwiseOpsFusionPatterns(
     RewritePatternSet &patterns,
-    const ControlFusionFn &controlElementwiseOpsFusion) {
+    const ControlFusionFn &controlElementwiseOpsFusion,
+    llvm::StringSet<> *blacklistedReductionFusionOps) {
   auto *context = patterns.getContext();
-  patterns.add<FuseElementwiseOps>(context, controlElementwiseOpsFusion);
+  if (blacklistedReductionFusionOps)
+    patterns.add<FuseElementwiseOps>(context, controlElementwiseOpsFusion,
+                                     *blacklistedReductionFusionOps);
+  else {
+    llvm::StringSet<> emptyBlacklistedReductionFusionOps;
+    patterns.add<FuseElementwiseOps>(context, controlElementwiseOpsFusion,
+                                     emptyBlacklistedReductionFusionOps);
+  }
   patterns.add<FoldFillWithGenericOp, FoldScalarOrSplatConstant,
                RemoveOutsDependency>(context);
   // Add the patterns that clean up dead operands and results.
@@ -2282,11 +2314,18 @@ struct LinalgElementwiseOpFusionPass
           LinalgElementwiseOpFusionPass> {
   using impl::LinalgElementwiseOpFusionPassBase<
       LinalgElementwiseOpFusionPass>::LinalgElementwiseOpFusionPassBase;
+
+  llvm::StringSet<> blacklistedReductionFusionOps;
+
   void runOnOperation() override {
     Operation *op = getOperation();
     MLIRContext *context = op->getContext();
     RewritePatternSet patterns(context);
 
+    for (const auto &opName : reductionFusionOpBlacklist) {
+      blacklistedReductionFusionOps.insert(opName);
+    }
+
     // Add folding with reshape by expansion patterns.
     ControlFusionFn defaultControlFn = [](OpOperand *fusedOperand) {
       Operation *producer = fusedOperand->get().getDefiningOp();
@@ -2294,7 +2333,8 @@ struct LinalgElementwiseOpFusionPass
     };
 
     // Add elementwise op fusion patterns.
-    populateElementwiseOpsFusionPatterns(patterns, defaultControlFn);
+    populateElementwiseOpsFusionPatterns(patterns, defaultControlFn,
+                                         &blacklistedReductionFusionOps);
     populateFoldReshapeOpsByExpansionPatterns(patterns, defaultControlFn);
     tensor::populateBubbleUpExpandShapePatterns(patterns);
 
diff --git a/mlir/test/Dialect/Linalg/fusion-elementwise-blacklist.mlir b/mlir/test/Dialect/Linalg/fusion-elementwise-blacklist.mlir
new file mode 100644
index 0000000000000..222c73b7695ee
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/fusion-elementwise-blacklist.mlir
@@ -0,0 +1,49 @@
+// RUN: mlir-opt %s -test-linalg-elementwise-fusion-patterns=blacklist-ops-for-reduction -split-input-file | FileCheck %s
+
+#map0 = affine_map<(d0, d1) -> (d0, d1)>
+#map1 = affine_map<(d0) -> (0, d0)>
+#map2 = affine_map<(d0) -> (0)>
+func.func @consumer_with_reduction_blacklist(%arg0: tensor<1x10xf32>,
+                              %arg1: tensor<1x10xf32>,
+                              %arg2: tensor<1xf32>) -> tensor<1xf32> {
+  %init = tensor.empty() : tensor<1x10xf32>
+  %0 = linalg.generic
+    {indexing_maps = [#map0, #map0, #map0],
+     iterator_types = ["parallel", "parallel"]}
+    ins(%arg0, %arg1 : tensor<1x10xf32>, tensor<1x10xf32>)
+    outs(%init : tensor<1x10xf32>) {
+  ^bb0(%arg3: f32, %arg4: f32, %arg5: f32):
+    %2 = arith.addf %arg3, %arg4 : f32
+    linalg.yield %2 : f32
+  } -> tensor<1x10xf32>
+  %1 = linalg.generic
+    {indexing_maps = [#map1, #map2],
+     iterator_types = ["reduction"]}
+    ins(%0 : tensor<1x10xf32>)
+    outs(%arg2 : tensor<1xf32>)  {
+  ^bb0(%arg3: f32, %arg4: f32):
+    %2 = arith.addf %arg3, %arg4 : f32
+    linalg.yield %2 : f32
+  } -> tensor<1xf32>
+  return %1 : tensor<1xf32>
+}
+//  CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1) -> (d0, d1)>
+//  CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0) -> (0, d0)>
+//  CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0) -> (0)>
+//      CHECK: func @consumer_with_reduction_blacklist(%[[ARG0:.+]]: tensor<1x10xf32>, %[[ARG1:.+]]: tensor<1x10xf32>, %[[ARG2:.+]]: tensor<1xf32>)
+//      CHECK:   %[[RES0:.+]] = linalg.generic
+// CHECK-SAME:     indexing_maps = [#[[MAP0]], #[[MAP0]], #[[MAP0]]]
+// CHECK-SAME:     iterator_types = ["parallel", "parallel"]
+// CHECK-SAME:     ins(%[[ARG0]], %[[ARG1]] : tensor<1x10xf32>, tensor<1x10xf32>)
+//      CHECK:   ^{{.+}}(%[[T0:.+]]: f32, %[[T1:.+]]: f32, %[[T2:.+]]: f32)
+//      CHECK:     %[[T3:.+]] = arith.addf %[[T0]], %[[T1]] : f32
+//      CHECK:     linalg.yield %[[T3]]
+//      CHECK:   %[[RES1:.+]] = linalg.generic
+// CHECK-SAME:     indexing_maps = [#[[MAP1]], #[[MAP2]]]
+// CHECK-SAME:     iterator_types = ["reduction"]
+// CHECK-SAME:     ins(%[[RES0]] : tensor<1x10xf32>)
+//      CHECK:   ^{{.+}}(%[[T0:.+]]: f32, %[[T1:.+]]: f32)
+//      CHECK:     %[[T2:.+]] = arith.addf %[[T0]], %[[T1]] : f32
+//      CHECK:     linalg.yield %[[T2]]
+//      CHECK:   return %[[RES1]]
+
diff --git a/mlir/test/lib/Dialect/Linalg/TestLinalgElementwiseFusion.cpp b/mlir/test/lib/Dialect/Linalg/TestLinalgElementwiseFusion.cpp
index cb215197253bb..801d72c6c9eac 100644
--- a/mlir/test/lib/Dialect/Linalg/TestLinalgElementwiseFusion.cpp
+++ b/mlir/test/lib/Dialect/Linalg/TestLinalgElementwiseFusion.cpp
@@ -60,8 +60,9 @@ struct TestMultiUseProducerFusion : public OpRewritePattern<linalg::GenericOp> {
   LogicalResult matchAndRewrite(linalg::GenericOp genericOp,
                                 PatternRewriter &rewriter) const override {
     OpOperand *fusableOperand = nullptr;
+    llvm::StringSet<> blacklist;
     for (OpOperand &operand : genericOp->getOpOperands()) {
-      if (linalg::areElementwiseOpsFusable(&operand)) {
+      if (linalg::areElementwiseOpsFusable(&operand, blacklist)) {
         fusableOperand = &operand;
         break;
       }
@@ -70,7 +71,7 @@ struct TestMultiUseProducerFusion : public OpRewritePattern<linalg::GenericOp> {
       return rewriter.notifyMatchFailure(genericOp, "no fusable operand found");
     }
     std::optional<linalg::ElementwiseOpFusionResult> fusionResult =
-        linalg::fuseElementwiseOps(rewriter, fusableOperand);
+        linalg::fuseElementwiseOps(rewriter, fusableOperand, blacklist);
     if (!fusionResult)
       return rewriter.notifyMatchFailure(genericOp, "fusion failed");
     for (auto [origValue, replacement] : fusionResult->replacements) {
@@ -143,6 +144,12 @@ struct TestLinalgElementwiseFusion
       llvm::cl::desc("Test fusion of producer ops with multiple uses"),
       llvm::cl::init(false)};
 
+  Option<bool> blacklistOpsForReduction{
+      *this, "blacklist-ops-for-reduction",
+      llvm::cl::desc(
+          "Test fusion of generic operations with a control function."),
+      llvm::cl::init(false)};
+
   ListOption<int64_t> collapseDimensions{
       *this, "collapse-dimensions-control",
       llvm::cl::desc("Test controlling dimension collapse pattern")};
@@ -257,6 +264,20 @@ struct TestLinalgElementwiseFusion
       return;
     }
 
+    if (blacklistOpsForReduction) {
+      RewritePatternSet fusionPatterns(context);
+      auto controlFn = [](OpOperand *operand) { return true; };
+      llvm::StringSet<> blacklist;
+      blacklist.insert("arith.addf");
+
+      linalg::populateElementwiseOpsFusionPatterns(fusionPatterns, controlFn,
+                                                   &blacklist);
+      if (failed(applyPatternsAndFoldGreedily(funcOp.getBody(),
+                                              std::move(fusionPatterns))))
+        return signalPassFailure();
+      return;
+    }
+
     if (!collapseDimensions.empty()) {
       SmallVector<int64_t, 2> dims(collapseDimensions.begin(),
                                    collapseDimensions.end());



More information about the Mlir-commits mailing list