[Mlir-commits] [mlir] [mlir][linalg] Add merge consecutive linalg::reduceOp canonicalization (PR #195048)
Hocky Yudhiono
llvmlistbot at llvm.org
Thu May 7 23:24:52 PDT 2026
https://github.com/hockyy updated https://github.com/llvm/llvm-project/pull/195048
>From 967981dbf44386015e1d189c0a842a8fddf8627a Mon Sep 17 00:00:00 2001
From: Hocky Yudhiono <hocky.yudhiono at gmail.com>
Date: Thu, 30 Apr 2026 18:00:25 +0800
Subject: [PATCH] [mlir][linalg] Add merge consecutive linalg::reduceOp
canonicalization
---
mlir/include/mlir/Dialect/Linalg/Passes.td | 10 ++
.../Dialect/Linalg/Transforms/CMakeLists.txt | 1 +
.../Transforms/MergeConsecutiveReduceOps.cpp | 115 ++++++++++++++++++
.../Linalg/merge-consecutive-reduce-ops.mlir | 45 +++++++
4 files changed, 171 insertions(+)
create mode 100644 mlir/lib/Dialect/Linalg/Transforms/MergeConsecutiveReduceOps.cpp
create mode 100644 mlir/test/Dialect/Linalg/merge-consecutive-reduce-ops.mlir
diff --git a/mlir/include/mlir/Dialect/Linalg/Passes.td b/mlir/include/mlir/Dialect/Linalg/Passes.td
index b873f260e7d92..573b7dc097aa3 100644
--- a/mlir/include/mlir/Dialect/Linalg/Passes.td
+++ b/mlir/include/mlir/Dialect/Linalg/Passes.td
@@ -168,6 +168,16 @@ def LinalgInlineScalarOperandsPass : Pass<"linalg-inline-scalar-operands"> {
];
}
+// Pass that fuses back-to-back `linalg.reduce` ops. The description below is
+// emitted into the generated pass documentation.
+def LinalgMergeConsecutiveReduceOpsPass
+    : Pass<"linalg-merge-consecutive-reduce-ops"> {
+  let summary = "Merge consecutive linalg.reduce ops";
+  let description = [{
+    Merge consecutive `linalg.reduce` ops that use the same combiner into a
+    single `linalg.reduce` over the original input.
+
+    Two reduces are merged only when the second one has exactly one input, that
+    input is the result of the first reduce and has no other users, both ops
+    are in the same block, and their combiner regions are equivalent. The
+    merged op reads the inputs of the first reduce, uses the inits of the
+    second reduce, and reduces every dimension removed by either op, expressed
+    as dimensions of the original input.
+
+    Note that the init of the first reduce is not used by the merged op, so the
+    rewrite only preserves the computed values when that init holds the neutral
+    element of the combiner (e.g. zero for `arith.addf`).
+
+    Example:
+
+    ```mlir
+    %0 = linalg.reduce { arith.addf }
+           ins(%in : tensor<2x3x4x5xf32>) outs(%zero : tensor<3x5xf32>)
+           dimensions = [0, 2]
+    %1 = linalg.reduce { arith.addf }
+           ins(%0 : tensor<3x5xf32>) outs(%init : tensor<f32>)
+           dimensions = [0, 1]
+    ```
+
+    is rewritten to
+
+    ```mlir
+    %1 = linalg.reduce { arith.addf }
+           ins(%in : tensor<2x3x4x5xf32>) outs(%init : tensor<f32>)
+           dimensions = [0, 1, 2, 3]
+    ```
+  }];
+  let dependentDialects = ["linalg::LinalgDialect"];
+}
+
def LinalgFoldIntoElementwisePass : Pass<"linalg-fold-into-elementwise"> {
let summary = "Fold transpose and broadcast ops into elementwise";
let dependentDialects = ["linalg::LinalgDialect"];
diff --git a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
index a2149478e4c2d..71818192fc4db 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
@@ -22,6 +22,7 @@ add_mlir_dialect_library(MLIRLinalgTransforms
InlineScalarOperands.cpp
Interchange.cpp
Loops.cpp
+ MergeConsecutiveReduceOps.cpp
MorphOps.cpp
TransposeMatmul.cpp
ShardingInterfaceImpl.cpp
diff --git a/mlir/lib/Dialect/Linalg/Transforms/MergeConsecutiveReduceOps.cpp b/mlir/lib/Dialect/Linalg/Transforms/MergeConsecutiveReduceOps.cpp
new file mode 100644
index 0000000000000..acfbf099f55ec
--- /dev/null
+++ b/mlir/lib/Dialect/Linalg/Transforms/MergeConsecutiveReduceOps.cpp
@@ -0,0 +1,115 @@
+//===- MergeConsecutiveReduceOps.cpp - Merge linalg.reduce ops ------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/Linalg/Passes.h"
+#include "mlir/IR/IRMapping.h"
+#include "mlir/IR/OperationSupport.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Support/LLVM.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "llvm/ADT/BitVector.h"
+#include "llvm/ADT/STLExtras.h"
+
+namespace mlir {
+#define GEN_PASS_DEF_LINALGMERGECONSECUTIVEREDUCEOPSPASS
+#include "mlir/Dialect/Linalg/Passes.h.inc"
+} // namespace mlir
+
+using namespace mlir;
+using namespace mlir::linalg;
+
+#define DEBUG_TYPE "linalg-merge-consecutive-reduce-ops"
+
+namespace {
+/// Rewrites a `linalg.reduce` whose single input is the only use of the result
+/// of another `linalg.reduce` in the same block, with an equivalent combiner
+/// region, into one `linalg.reduce` over the producer's inputs.
+///
+/// NOTE(review): the merged op is seeded from the consumer's inits only; the
+/// producer's inits are dropped. That looks value-preserving only when the
+/// producer's init holds the combiner's neutral element -- confirm, or extend
+/// the match to check it.
+struct MergeConsecutiveReduceOp : OpRewritePattern<ReduceOp> {
+  using OpRewritePattern<ReduceOp>::OpRewritePattern;
+
+  /// Tries to fuse `consumer` with the reduce producing its input. Returns
+  /// success after replacing both ops with the merged reduce, or a match
+  /// failure (reported through `rewriter`) without touching the IR.
+  LogicalResult matchAndRewrite(ReduceOp consumer,
+                                PatternRewriter &rewriter) const override {
+    if (consumer.getNumDpsInputs() != 1)
+      return rewriter.notifyMatchFailure(
+          consumer, "only supports second reduce op with one input");
+
+    // The producer is erased below, so nothing else may read its result.
+    Value input = consumer.getDpsInputs().front();
+    if (!input.hasOneUse())
+      return rewriter.notifyMatchFailure(
+          consumer, "does not support producer result with multiple users");
+
+    auto producer = input.getDefiningOp<ReduceOp>();
+    if (!producer)
+      return rewriter.notifyMatchFailure(consumer,
+                                         "does not find consecutive reduces");
+    if (consumer->getBlock() != producer->getBlock())
+      return rewriter.notifyMatchFailure(
+          consumer, "does not support reduce in different blocks");
+
+    // Both ops must apply the same combiner, modulo source locations.
+    if (!OperationEquivalence::isRegionEquivalentTo(
+            &consumer.getRegion(), &producer.getRegion(),
+            OperationEquivalence::Flags::IgnoreLocations))
+      return rewriter.notifyMatchFailure(
+          consumer, "reduce operation regions are not equal");
+
+    SmallVector<unsigned> prodDims, consDims;
+    producer.getReductionDims(prodDims);
+    consumer.getReductionDims(consDims);
+    int64_t inputRank =
+        cast<ShapedType>(producer.getDpsInputs()[0].getType()).getRank();
+    SmallVector<int64_t> dims =
+        mergeConsecutiveReduceDims(prodDims, consDims, inputRank);
+
+    rewriter.setInsertionPointAfter(consumer);
+    auto newReduce = ReduceOp::create(
+        rewriter, consumer->getLoc(), TypeRange(consumer->getResults()),
+        producer.getInputs(), consumer.getInits(), dims);
+    // Hand the combiner over through the rewriter instead of mutating the new
+    // region directly, so every IR change is visible to the rewrite driver.
+    // The consumer is replaced right below, so its body can be moved rather
+    // than cloned.
+    rewriter.inlineRegionBefore(consumer.getRegion(), newReduce.getRegion(),
+                                newReduce.getRegion().begin());
+
+    rewriter.replaceOp(consumer, newReduce);
+    rewriter.eraseOp(producer);
+    return success();
+  }
+
+  /// Returns the sorted reduction dims, expressed on the producer's input, that
+  /// are equivalent to reducing `prodDims` first and `consDims` second.
+  ///
+  /// `prodDims` index the producer's input (of rank `maxRank`); `consDims`
+  /// index the producer's result, i.e. the input dims that survive `prodDims`.
+  /// For example, with `maxRank` = 4, `prodDims` = [0, 2] leaves input dims
+  /// [1, 3]; `consDims` = [0, 1] then selects both, giving [0, 1, 2, 3].
+  SmallVector<int64_t> mergeConsecutiveReduceDims(ArrayRef<unsigned> prodDims,
+                                                  ArrayRef<unsigned> consDims,
+                                                  unsigned maxRank) const {
+    // Bit `i` stays set while input dim `i` is not reduced by the producer.
+    BitVector availableMask(maxRank, true);
+    for (unsigned dim : prodDims)
+      availableMask[dim] = false;
+
+    // remainingDimIndex[i] is the input dim that producer result dim `i` maps
+    // to.
+    SmallVector<unsigned> remainingDimIndex;
+    for (unsigned i = 0; i < maxRank; i++)
+      if (availableMask[i])
+        remainingDimIndex.push_back(i);
+
+    SmallVector<int64_t> newDims(prodDims);
+    for (unsigned dim : consDims)
+      newDims.push_back(remainingDimIndex[dim]);
+    llvm::sort(newDims);
+    return newDims;
+  }
+};
+
+/// Pass wrapper that greedily applies `MergeConsecutiveReduceOp` to the
+/// operation the pass runs on.
+struct LinalgMergeConsecutiveReduceOpsPass
+    : public impl::LinalgMergeConsecutiveReduceOpsPassBase<
+          LinalgMergeConsecutiveReduceOpsPass> {
+  using impl::LinalgMergeConsecutiveReduceOpsPassBase<
+      LinalgMergeConsecutiveReduceOpsPass>::
+      LinalgMergeConsecutiveReduceOpsPassBase;
+
+  /// Registers the merge pattern and runs the greedy driver; the pass is
+  /// marked as failed if the driver reports failure.
+  void runOnOperation() override {
+    MLIRContext *context = &getContext();
+    RewritePatternSet mergePatterns(context);
+    mergePatterns.add<MergeConsecutiveReduceOp>(context);
+
+    if (failed(applyPatternsGreedily(getOperation(), std::move(mergePatterns))))
+      signalPassFailure();
+  }
+};
+} // namespace
diff --git a/mlir/test/Dialect/Linalg/merge-consecutive-reduce-ops.mlir b/mlir/test/Dialect/Linalg/merge-consecutive-reduce-ops.mlir
new file mode 100644
index 0000000000000..7a8050d112846
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/merge-consecutive-reduce-ops.mlir
@@ -0,0 +1,45 @@
+// RUN: mlir-opt %s -linalg-merge-consecutive-reduce-ops -split-input-file | FileCheck %s
+
+// The first reduce removes dims 0 and 2 of the rank-4 input, so its 3x5 result
+// holds input dims 1 and 3. The second reduce removes both of those, so the
+// merged op is expected to reduce input dims [0, 1, 2, 3] directly into %init.
+// CHECK-LABEL: func.func @merge_consecutive_reduce(
+// CHECK-SAME: %[[INPUT:[a-zA-Z0-9_]+]]: tensor<2x3x4x5xf32>
+// CHECK-SAME: %[[INIT:[a-zA-Z0-9_]+]]: tensor<f32>
+// CHECK: %[[REDUCED:.+]] = linalg.reduce { arith.addf } ins(%[[INPUT]] : tensor<2x3x4x5xf32>) outs(%[[INIT]] : tensor<f32>) dimensions = [0, 1, 2, 3]
+// CHECK-NEXT: return %[[REDUCED]] : tensor<f32>
+func.func @merge_consecutive_reduce(
+ %input: tensor<2x3x4x5xf32>, %init: tensor<f32>) -> tensor<f32> {
+ %cst = arith.constant 0.000000e+00 : f32
+ %empty = tensor.empty() : tensor<3x5xf32>
+ %fill = linalg.fill ins(%cst : f32) outs(%empty : tensor<3x5xf32>) -> tensor<3x5xf32>
+ %first_reduce = linalg.reduce { arith.addf }
+ ins(%input : tensor<2x3x4x5xf32>)
+ outs(%fill : tensor<3x5xf32>)
+ dimensions = [0, 2]
+ %second_reduce = linalg.reduce { arith.addf }
+ ins(%first_reduce : tensor<3x5xf32>)
+ outs(%init : tensor<f32>)
+ dimensions = [0, 1]
+ return %second_reduce : tensor<f32>
+}
+
+// -----
+
+// The first reduce removes dims 0 and 4 of the rank-5 input, so its 3x4x5
+// result holds input dims 1, 2 and 3. The second reduce removes result dims 0
+// and 1, i.e. input dims 1 and 2, so the merged op is expected to reduce input
+// dims [0, 1, 2, 4] and keep input dim 3 (size 5).
+// CHECK-LABEL: func.func @merge_consecutive_reduce_with_projected_dims(
+// CHECK-SAME: %[[INPUT:[a-zA-Z0-9_]+]]: tensor<2x3x4x5x6xf32>
+// CHECK-SAME: %[[INIT:[a-zA-Z0-9_]+]]: tensor<5xf32>
+// CHECK: %[[REDUCED:.+]] = linalg.reduce { arith.addf } ins(%[[INPUT]] : tensor<2x3x4x5x6xf32>) outs(%[[INIT]] : tensor<5xf32>) dimensions = [0, 1, 2, 4]
+// CHECK-NEXT: return %[[REDUCED]] : tensor<5xf32>
+func.func @merge_consecutive_reduce_with_projected_dims(
+ %input: tensor<2x3x4x5x6xf32>, %init: tensor<5xf32>) -> tensor<5xf32> {
+ %cst = arith.constant 0.000000e+00 : f32
+ %empty = tensor.empty() : tensor<3x4x5xf32>
+ %fill = linalg.fill ins(%cst : f32) outs(%empty : tensor<3x4x5xf32>) -> tensor<3x4x5xf32>
+ %first_reduce = linalg.reduce { arith.addf }
+ ins(%input : tensor<2x3x4x5x6xf32>)
+ outs(%fill : tensor<3x4x5xf32>)
+ dimensions = [0, 4]
+ %second_reduce = linalg.reduce { arith.addf }
+ ins(%first_reduce : tensor<3x4x5xf32>)
+ outs(%init : tensor<5xf32>)
+ dimensions = [0, 1]
+ return %second_reduce : tensor<5xf32>
+}
More information about the Mlir-commits
mailing list