[Mlir-commits] [mlir] [mlir] Allow blacklist ops for reduction linalg elementwise fusion (PR #144176)
Evan Liu
llvmlistbot at llvm.org
Fri Jun 13 18:55:10 PDT 2025
https://github.com/Evanyl updated https://github.com/llvm/llvm-project/pull/144176
>From afedeae592376192217cf1faa0033df06c2ce7f0 Mon Sep 17 00:00:00 2001
From: Evan Liu <liuyievan at gmail.com>
Date: Fri, 13 Jun 2025 17:59:57 -0700
Subject: [PATCH] [mlir] Allow blacklist ops for reduction linalg elementwise
fusion
---
mlir/include/mlir/Dialect/Linalg/Passes.td | 3 +
.../Dialect/Linalg/Transforms/Transforms.h | 9 ++-
.../Linalg/Transforms/ElementwiseOpFusion.cpp | 62 +++++++++++++++----
.../Linalg/fusion-elementwise-blacklist.mlir | 49 +++++++++++++++
.../Linalg/TestLinalgElementwiseFusion.cpp | 25 +++++++-
5 files changed, 132 insertions(+), 16 deletions(-)
create mode 100644 mlir/test/Dialect/Linalg/fusion-elementwise-blacklist.mlir
diff --git a/mlir/include/mlir/Dialect/Linalg/Passes.td b/mlir/include/mlir/Dialect/Linalg/Passes.td
index 373842c9b03de..5db234770ef5c 100644
--- a/mlir/include/mlir/Dialect/Linalg/Passes.td
+++ b/mlir/include/mlir/Dialect/Linalg/Passes.td
@@ -72,6 +72,9 @@ def LinalgFoldUnitExtentDimsPass : Pass<"linalg-fold-unit-extent-dims", ""> {
def LinalgElementwiseOpFusionPass : Pass<"linalg-fuse-elementwise-ops"> {
let summary = "Fuse elementwise operations on tensors";
+ let options = [ListOption<"reductionFusionOpBlacklist",
+ "reduction-fusion-blacklist", "std::string",
+ "List of ops to blacklist for reduction fusion.">];
let dependentDialects = [
"affine::AffineDialect", "linalg::LinalgDialect", "memref::MemRefDialect"
];
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 2eef0a06d0eb4..e4968930ce554 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -450,7 +450,8 @@ using ControlSplitReductionFn =
/// Return true if two `linalg.generic` operations with producer/consumer
/// relationship through `fusedOperand` can be fused using elementwise op
/// fusion.
-bool areElementwiseOpsFusable(OpOperand *fusedOperand);
+bool areElementwiseOpsFusable(OpOperand *fusedOperand,
+ llvm::StringSet<> &blacklistedReductionFusionOps);
/// Promote memref.subviews feeding linalg-on-buffers operations.
LogicalResult promoteSubviewsPrecondition(Operation *op,
@@ -505,7 +506,8 @@ struct ElementwiseOpFusionResult {
llvm::DenseMap<Value, Value> replacements;
};
FailureOr<ElementwiseOpFusionResult>
-fuseElementwiseOps(RewriterBase &rewriter, OpOperand *fusedOperand);
+fuseElementwiseOps(RewriterBase &rewriter, OpOperand *fusedOperand,
+ llvm::StringSet<> &blacklistedReductionFusionOps);
/// Returns a set of indices of the producer's results which would
/// be preserved after the fusion.
@@ -1783,7 +1785,8 @@ using ControlFusionFn = std::function<bool(OpOperand *fusedOperand)>;
/// when both operations are fusable elementwise operations.
void populateElementwiseOpsFusionPatterns(
RewritePatternSet &patterns,
- const ControlFusionFn &controlElementwiseOpFusion);
+ const ControlFusionFn &controlElementwiseOpFusion,
+ llvm::StringSet<> *blacklistedReductionFusionOps = nullptr);
/// Function type which is used to control propagation of linalg.pack/unpack
/// ops.
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
index f97ed3d6d5111..e10bba04951b7 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
@@ -104,6 +104,20 @@ static bool isOpOperandCanBeDroppedAfterFusedLinalgs(
indexingMaps, producer.getContext())) != AffineMap();
}
+/// Returns true unless the body of `op` contains a non-terminator operation
+/// whose name appears in `blacklistedReductionFusionOps`.
+static bool
+shouldFuseIntoReduction(linalg::GenericOp op,
+                        llvm::StringSet<> &blacklistedReductionFusionOps) {
+  for (Operation &nestedOp : op.getRegion().front()) {
+    const bool isTerminator = nestedOp.hasTrait<OpTrait::IsTerminator>();
+    if (!isTerminator && blacklistedReductionFusionOps.contains(
+                             nestedOp.getName().getStringRef()))
+      return false;
+  }
+  return true;
+}
+
/// Returns a set of indices of the producer's results which would
/// be preserved after the fusion.
/// * There is a chance that the implementation of the transformation does not
@@ -136,7 +150,8 @@ llvm::SmallDenseSet<int> mlir::linalg::getPreservedProducerResults(
}
/// Conditions for elementwise fusion of generic operations.
-bool mlir::linalg::areElementwiseOpsFusable(OpOperand *fusedOperand) {
+bool mlir::linalg::areElementwiseOpsFusable(
+ OpOperand *fusedOperand, llvm::StringSet<> &blacklistedReductionFusionOps) {
if (!fusedOperand)
return false;
@@ -159,6 +174,10 @@ bool mlir::linalg::areElementwiseOpsFusable(OpOperand *fusedOperand) {
if (producer.getNumParallelLoops() != producer.getNumLoops())
return false;
+ if (consumer.getNumReductionLoops() > 0 &&
+ !shouldFuseIntoReduction(producer, blacklistedReductionFusionOps))
+ return false;
+
// Only allow fusing the producer of an input operand for now.
// TODO: allow fusing the producer of an output operand.
if (!consumer.isDpsInput(fusedOperand))
@@ -335,10 +354,12 @@ static void generateFusedElementwiseOpRegion(
}
FailureOr<mlir::linalg::ElementwiseOpFusionResult>
-mlir::linalg::fuseElementwiseOps(RewriterBase &rewriter,
- OpOperand *fusedOperand) {
- assert(areElementwiseOpsFusable(fusedOperand) &&
- "expected elementwise operation pre-conditions to pass");
+mlir::linalg::fuseElementwiseOps(
+ RewriterBase &rewriter, OpOperand *fusedOperand,
+ llvm::StringSet<> &blacklistedReductionFusionOps) {
+ assert(
+ areElementwiseOpsFusable(fusedOperand, blacklistedReductionFusionOps) &&
+ "expected elementwise operation pre-conditions to pass");
auto producerResult = cast<OpResult>(fusedOperand->get());
auto producer = cast<GenericOp>(producerResult.getOwner());
auto consumer = cast<GenericOp>(fusedOperand->getOwner());
@@ -462,16 +483,19 @@ namespace {
/// Patterns to fuse a generic op, with the producer of its operands.
class FuseElementwiseOps : public OpRewritePattern<GenericOp> {
public:
+ llvm::StringSet<> &blacklistedReductionFusionOps;
FuseElementwiseOps(MLIRContext *context, ControlFusionFn fun,
+ llvm::StringSet<> &blacklistedReductionFusionOps,
PatternBenefit benefit = 1)
: OpRewritePattern<GenericOp>(context, benefit),
+ blacklistedReductionFusionOps(blacklistedReductionFusionOps),
controlFn(std::move(fun)) {}
LogicalResult matchAndRewrite(GenericOp genericOp,
PatternRewriter &rewriter) const override {
// Find the first operand that is defined by another generic op on tensors.
for (OpOperand &opOperand : genericOp->getOpOperands()) {
- if (!areElementwiseOpsFusable(&opOperand))
+ if (!areElementwiseOpsFusable(&opOperand, blacklistedReductionFusionOps))
continue;
if (!controlFn(&opOperand))
continue;
@@ -479,8 +503,8 @@ class FuseElementwiseOps : public OpRewritePattern<GenericOp> {
Operation *producer = opOperand.get().getDefiningOp();
// Find the producer of the operand.
- FailureOr<ElementwiseOpFusionResult> fusionResult =
- fuseElementwiseOps(rewriter, &opOperand);
+ FailureOr<ElementwiseOpFusionResult> fusionResult = fuseElementwiseOps(
+ rewriter, &opOperand, blacklistedReductionFusionOps);
if (failed(fusionResult))
return rewriter.notifyMatchFailure(genericOp, "fusion failed");
@@ -2248,9 +2272,17 @@ void mlir::linalg::populateFoldReshapeOpsByCollapsingPatterns(
void mlir::linalg::populateElementwiseOpsFusionPatterns(
RewritePatternSet &patterns,
- const ControlFusionFn &controlElementwiseOpsFusion) {
+ const ControlFusionFn &controlElementwiseOpsFusion,
+ llvm::StringSet<> *blacklistedReductionFusionOps) {
auto *context = patterns.getContext();
- patterns.add<FuseElementwiseOps>(context, controlElementwiseOpsFusion);
+  // The pattern keeps a reference to the set, so the fallback used when no
+  // blacklist is given must outlive this call; a stack local would dangle.
+  static llvm::StringSet<> emptyBlacklistedReductionFusionOps;
+  llvm::StringSet<> &reductionBlacklist =
+      blacklistedReductionFusionOps ? *blacklistedReductionFusionOps
+                                    : emptyBlacklistedReductionFusionOps;
+  patterns.add<FuseElementwiseOps>(context, controlElementwiseOpsFusion,
+                                   reductionBlacklist);
patterns.add<FoldFillWithGenericOp, FoldScalarOrSplatConstant,
RemoveOutsDependency>(context);
// Add the patterns that clean up dead operands and results.
@@ -2282,11 +2314,18 @@ struct LinalgElementwiseOpFusionPass
LinalgElementwiseOpFusionPass> {
using impl::LinalgElementwiseOpFusionPassBase<
LinalgElementwiseOpFusionPass>::LinalgElementwiseOpFusionPassBase;
+
+ llvm::StringSet<> blacklistedReductionFusionOps;
+
void runOnOperation() override {
Operation *op = getOperation();
MLIRContext *context = op->getContext();
RewritePatternSet patterns(context);
+ for (const auto &opName : reductionFusionOpBlacklist) {
+ blacklistedReductionFusionOps.insert(opName);
+ }
+
// Add folding with reshape by expansion patterns.
ControlFusionFn defaultControlFn = [](OpOperand *fusedOperand) {
Operation *producer = fusedOperand->get().getDefiningOp();
@@ -2294,7 +2333,8 @@ struct LinalgElementwiseOpFusionPass
};
// Add elementwise op fusion patterns.
- populateElementwiseOpsFusionPatterns(patterns, defaultControlFn);
+ populateElementwiseOpsFusionPatterns(patterns, defaultControlFn,
+ &blacklistedReductionFusionOps);
populateFoldReshapeOpsByExpansionPatterns(patterns, defaultControlFn);
tensor::populateBubbleUpExpandShapePatterns(patterns);
diff --git a/mlir/test/Dialect/Linalg/fusion-elementwise-blacklist.mlir b/mlir/test/Dialect/Linalg/fusion-elementwise-blacklist.mlir
new file mode 100644
index 0000000000000..222c73b7695ee
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/fusion-elementwise-blacklist.mlir
@@ -0,0 +1,49 @@
+// RUN: mlir-opt %s -test-linalg-elementwise-fusion-patterns=blacklist-ops-for-reduction -split-input-file | FileCheck %s
+// Blacklisted arith.addf producer must not be fused into the reduction consumer.
+#map0 = affine_map<(d0, d1) -> (d0, d1)>
+#map1 = affine_map<(d0) -> (0, d0)>
+#map2 = affine_map<(d0) -> (0)>
+func.func @consumer_with_reduction_blacklist(%arg0: tensor<1x10xf32>,
+ %arg1: tensor<1x10xf32>,
+ %arg2: tensor<1xf32>) -> tensor<1xf32> {
+ %init = tensor.empty() : tensor<1x10xf32>
+ %0 = linalg.generic
+ {indexing_maps = [#map0, #map0, #map0],
+ iterator_types = ["parallel", "parallel"]}
+ ins(%arg0, %arg1 : tensor<1x10xf32>, tensor<1x10xf32>)
+ outs(%init : tensor<1x10xf32>) {
+ ^bb0(%arg3: f32, %arg4: f32, %arg5: f32):
+ %2 = arith.addf %arg3, %arg4 : f32
+ linalg.yield %2 : f32
+ } -> tensor<1x10xf32>
+ %1 = linalg.generic
+ {indexing_maps = [#map1, #map2],
+ iterator_types = ["reduction"]}
+ ins(%0 : tensor<1x10xf32>)
+ outs(%arg2 : tensor<1xf32>) {
+ ^bb0(%arg3: f32, %arg4: f32):
+ %2 = arith.addf %arg3, %arg4 : f32
+ linalg.yield %2 : f32
+ } -> tensor<1xf32>
+ return %1 : tensor<1xf32>
+}
+// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1) -> (d0, d1)>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0) -> (0, d0)>
+// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0) -> (0)>
+// CHECK: func @consumer_with_reduction_blacklist(%[[ARG0:.+]]: tensor<1x10xf32>, %[[ARG1:.+]]: tensor<1x10xf32>, %[[ARG2:.+]]: tensor<1xf32>)
+// CHECK: %[[RES0:.+]] = linalg.generic
+// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP0]], #[[MAP0]]]
+// CHECK-SAME: iterator_types = ["parallel", "parallel"]
+// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] : tensor<1x10xf32>, tensor<1x10xf32>)
+// CHECK: ^{{.+}}(%[[T0:.+]]: f32, %[[T1:.+]]: f32, %[[T2:.+]]: f32)
+// CHECK: %[[T3:.+]] = arith.addf %[[T0]], %[[T1]] : f32
+// CHECK: linalg.yield %[[T3]]
+// CHECK: %[[RES1:.+]] = linalg.generic
+// CHECK-SAME: indexing_maps = [#[[MAP1]], #[[MAP2]]]
+// CHECK-SAME: iterator_types = ["reduction"]
+// CHECK-SAME: ins(%[[RES0]] : tensor<1x10xf32>)
+// CHECK: ^{{.+}}(%[[T0:.+]]: f32, %[[T1:.+]]: f32)
+// CHECK: %[[T2:.+]] = arith.addf %[[T0]], %[[T1]] : f32
+// CHECK: linalg.yield %[[T2]]
+// CHECK: return %[[RES1]]
+
diff --git a/mlir/test/lib/Dialect/Linalg/TestLinalgElementwiseFusion.cpp b/mlir/test/lib/Dialect/Linalg/TestLinalgElementwiseFusion.cpp
index cb215197253bb..801d72c6c9eac 100644
--- a/mlir/test/lib/Dialect/Linalg/TestLinalgElementwiseFusion.cpp
+++ b/mlir/test/lib/Dialect/Linalg/TestLinalgElementwiseFusion.cpp
@@ -60,8 +60,9 @@ struct TestMultiUseProducerFusion : public OpRewritePattern<linalg::GenericOp> {
LogicalResult matchAndRewrite(linalg::GenericOp genericOp,
PatternRewriter &rewriter) const override {
OpOperand *fusableOperand = nullptr;
+ llvm::StringSet<> blacklist;
for (OpOperand &operand : genericOp->getOpOperands()) {
- if (linalg::areElementwiseOpsFusable(&operand)) {
+ if (linalg::areElementwiseOpsFusable(&operand, blacklist)) {
fusableOperand = &operand;
break;
}
@@ -70,7 +71,7 @@ struct TestMultiUseProducerFusion : public OpRewritePattern<linalg::GenericOp> {
return rewriter.notifyMatchFailure(genericOp, "no fusable operand found");
}
std::optional<linalg::ElementwiseOpFusionResult> fusionResult =
- linalg::fuseElementwiseOps(rewriter, fusableOperand);
+ linalg::fuseElementwiseOps(rewriter, fusableOperand, blacklist);
if (!fusionResult)
return rewriter.notifyMatchFailure(genericOp, "fusion failed");
for (auto [origValue, replacement] : fusionResult->replacements) {
@@ -143,6 +144,12 @@ struct TestLinalgElementwiseFusion
llvm::cl::desc("Test fusion of producer ops with multiple uses"),
llvm::cl::init(false)};
+  // Runs the fusion patterns with "arith.addf" blacklisted for reductions.
+  Option<bool> blacklistOpsForReduction{
+      *this, "blacklist-ops-for-reduction",
+      llvm::cl::desc("Test blacklisting ops from fusion into reductions."),
+      llvm::cl::init(false)};
+
ListOption<int64_t> collapseDimensions{
*this, "collapse-dimensions-control",
llvm::cl::desc("Test controlling dimension collapse pattern")};
@@ -257,6 +264,20 @@ struct TestLinalgElementwiseFusion
return;
}
+ if (blacklistOpsForReduction) {
+ RewritePatternSet fusionPatterns(context);
+ auto controlFn = [](OpOperand *operand) { return true; };
+ llvm::StringSet<> blacklist;
+ blacklist.insert("arith.addf");
+
+ linalg::populateElementwiseOpsFusionPatterns(fusionPatterns, controlFn,
+ &blacklist);
+ if (failed(applyPatternsAndFoldGreedily(funcOp.getBody(),
+ std::move(fusionPatterns))))
+ return signalPassFailure();
+ return;
+ }
+
if (!collapseDimensions.empty()) {
SmallVector<int64_t, 2> dims(collapseDimensions.begin(),
collapseDimensions.end());
More information about the Mlir-commits
mailing list