[Mlir-commits] [mlir] [mlir][linalg] Add merge consecutive linalg::reduceOp canonicalization (PR #195048)

Hocky Yudhiono llvmlistbot at llvm.org
Thu May 7 23:24:52 PDT 2026


https://github.com/hockyy updated https://github.com/llvm/llvm-project/pull/195048

>From 967981dbf44386015e1d189c0a842a8fddf8627a Mon Sep 17 00:00:00 2001
From: Hocky Yudhiono <hocky.yudhiono at gmail.com>
Date: Thu, 30 Apr 2026 18:00:25 +0800
Subject: [PATCH] [mlir][linalg] Add merge consecutive linalg::reduceOp
 canonicalization

---
 mlir/include/mlir/Dialect/Linalg/Passes.td    |  10 ++
 .../Dialect/Linalg/Transforms/CMakeLists.txt  |   1 +
 .../Transforms/MergeConsecutiveReduceOps.cpp  | 115 ++++++++++++++++++
 .../Linalg/merge-consecutive-reduce-ops.mlir  |  45 +++++++
 4 files changed, 171 insertions(+)
 create mode 100644 mlir/lib/Dialect/Linalg/Transforms/MergeConsecutiveReduceOps.cpp
 create mode 100644 mlir/test/Dialect/Linalg/merge-consecutive-reduce-ops.mlir

diff --git a/mlir/include/mlir/Dialect/Linalg/Passes.td b/mlir/include/mlir/Dialect/Linalg/Passes.td
index b873f260e7d92..573b7dc097aa3 100644
--- a/mlir/include/mlir/Dialect/Linalg/Passes.td
+++ b/mlir/include/mlir/Dialect/Linalg/Passes.td
@@ -168,6 +168,16 @@ def LinalgInlineScalarOperandsPass : Pass<"linalg-inline-scalar-operands"> {
   ];
 }
 
+// Pass wrapper that greedily applies the MergeConsecutiveReduceOp rewrite
+// pattern (MergeConsecutiveReduceOps.cpp) to the op it is scheduled on.
+def LinalgMergeConsecutiveReduceOpsPass
+    : Pass<"linalg-merge-consecutive-reduce-ops"> {
+  let dependentDialects = ["linalg::LinalgDialect"];
+  let summary = "Merge consecutive linalg.reduce ops";
+  let description = [{
+    Merge consecutive `linalg.reduce` ops that use the same combiner into a
+    single `linalg.reduce` over the original input.
+  }];
+}
+
 def LinalgFoldIntoElementwisePass : Pass<"linalg-fold-into-elementwise"> {
   let summary = "Fold transpose and broadcast ops into elementwise";
   let dependentDialects = ["linalg::LinalgDialect"];
diff --git a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
index a2149478e4c2d..71818192fc4db 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
@@ -22,6 +22,7 @@ add_mlir_dialect_library(MLIRLinalgTransforms
   InlineScalarOperands.cpp
   Interchange.cpp
   Loops.cpp
+  MergeConsecutiveReduceOps.cpp
   MorphOps.cpp
   TransposeMatmul.cpp
   ShardingInterfaceImpl.cpp
diff --git a/mlir/lib/Dialect/Linalg/Transforms/MergeConsecutiveReduceOps.cpp b/mlir/lib/Dialect/Linalg/Transforms/MergeConsecutiveReduceOps.cpp
new file mode 100644
index 0000000000000..acfbf099f55ec
--- /dev/null
+++ b/mlir/lib/Dialect/Linalg/Transforms/MergeConsecutiveReduceOps.cpp
@@ -0,0 +1,115 @@
+//===- MergeConsecutiveReduceOps.cpp - Merge linalg.reduce ops ------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/Linalg/Passes.h"
+#include "mlir/IR/IRMapping.h"
+#include "mlir/IR/Matchers.h"
+#include "mlir/IR/OperationSupport.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Support/LLVM.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "llvm/ADT/BitVector.h"
+#include "llvm/ADT/STLExtras.h"
+
+// Pull in the TableGen-generated definition of
+// impl::LinalgMergeConsecutiveReduceOpsPassBase, which the pass below derives
+// from. The GEN_PASS_DEF_* macro must be defined before the .inc include.
+namespace mlir {
+#define GEN_PASS_DEF_LINALGMERGECONSECUTIVEREDUCEOPSPASS
+#include "mlir/Dialect/Linalg/Passes.h.inc"
+} // namespace mlir
+
+using namespace mlir;
+using namespace mlir::linalg;
+
+#define DEBUG_TYPE "linalg-merge-consecutive-reduce-ops"
+
+namespace {
+/// Rewrites `reduce(reduce(x, prodDims), consDims)` into a single
+/// `reduce(x, mergedDims)` that writes into the second op's init.
+///
+/// The merged op no longer reads the first op's init, so the rewrite only
+/// fires when dropping it cannot change the result:
+///   * both ops have exactly one input (and therefore one init),
+///   * the first op's result has no other users and both ops share a block,
+///   * both bodies are equivalent and consist of a single binary combiner on
+///     the block arguments for which `arith::getNeutralElement` knows an
+///     identity (this rejects non-reduction combiners such as `arith.subf`),
+///   * the first op's init is a `linalg.fill` of exactly that identity.
+///
+/// NOTE(review): for floating-point combiners the merge reassociates the
+/// reduction order; confirm that is acceptable for all intended users.
+struct MergeConsecutiveReduceOp : OpRewritePattern<ReduceOp> {
+  using OpRewritePattern<ReduceOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(ReduceOp consumer,
+                                PatternRewriter &rewriter) const override {
+    if (consumer.getNumDpsInputs() != 1) {
+      return rewriter.notifyMatchFailure(
+          consumer, "only supports second reduce op with one input");
+    }
+    Value input = consumer.getDpsInputs().front();
+    if (!input.hasOneUse()) {
+      return rewriter.notifyMatchFailure(
+          consumer, "does not support producer result with multiple users");
+    }
+    auto producer = input.getDefiningOp<ReduceOp>();
+    if (!producer) {
+      return rewriter.notifyMatchFailure(consumer,
+                                         "does not find consecutive reduces");
+    }
+    // A variadic producer has several inputs/results; they could not all be
+    // paired with the consumer's single init in the merged op.
+    if (producer.getNumDpsInputs() != 1) {
+      return rewriter.notifyMatchFailure(
+          consumer, "only supports first reduce op with one input");
+    }
+    if (consumer->getBlock() != producer->getBlock()) {
+      return rewriter.notifyMatchFailure(
+          consumer, "does not support reduce in different blocks");
+    }
+    if (!OperationEquivalence::isRegionEquivalentTo(
+            &consumer.getRegion(), &producer.getRegion(),
+            OperationEquivalence::Flags::IgnoreLocations)) {
+      return rewriter.notifyMatchFailure(
+          consumer, "reduce operation regions are not equal");
+    }
+    // The producer's init is discarded by the merge, which is only
+    // value-preserving when it is the identity of the shared combiner.
+    Operation *combiner = getSingleCombiner(producer);
+    if (!combiner) {
+      return rewriter.notifyMatchFailure(
+          consumer, "reduce body is not a single binary combiner");
+    }
+    if (!isNeutralFill(producer.getInits().front(), combiner)) {
+      return rewriter.notifyMatchFailure(
+          consumer, "first reduce init is not the combiner's neutral element");
+    }
+
+    SmallVector<unsigned> prodDims, consDims;
+    producer.getReductionDims(prodDims);
+    consumer.getReductionDims(consDims);
+    auto maxRank =
+        cast<ShapedType>(producer.getDpsInputs()[0].getType()).getRank();
+    SmallVector<int64_t> dims =
+        mergeConsecutiveReduceDims(prodDims, consDims, maxRank);
+
+    rewriter.setInsertionPointAfter(consumer);
+    auto newReduce = ReduceOp::create(
+        rewriter, consumer->getLoc(), TypeRange(consumer->getResults()),
+        producer.getInputs(), consumer.getInits(), dims);
+    // Clone the body through the rewriter so the driver is notified of the
+    // newly created ops.
+    Region &newRegion = newReduce.getRegion();
+    rewriter.cloneRegionBefore(consumer.getRegion(), newRegion,
+                               newRegion.end());
+
+    rewriter.replaceOp(consumer, newReduce);
+    // The producer's only user was the consumer, so it is now dead.
+    rewriter.eraseOp(producer);
+    return success();
+  }
+
+  /// Returns the combiner op of `reduce` if its body is exactly
+  /// `%r = <op> %a, %b; linalg.yield %r`, where {%a, %b} are the two block
+  /// arguments (in either order); otherwise returns nullptr.
+  static Operation *getSingleCombiner(ReduceOp reduce) {
+    Block &body = reduce.getRegion().front();
+    if (body.getNumArguments() != 2 ||
+        !llvm::hasSingleElement(body.without_terminator()))
+      return nullptr;
+    Operation &combiner = body.front();
+    if (combiner.getNumOperands() != 2 || combiner.getNumResults() != 1)
+      return nullptr;
+    Value lhs = combiner.getOperand(0);
+    Value rhs = combiner.getOperand(1);
+    Value in = body.getArgument(0);
+    Value acc = body.getArgument(1);
+    if (!((lhs == in && rhs == acc) || (lhs == acc && rhs == in)))
+      return nullptr;
+    Operation *yield = body.getTerminator();
+    if (yield->getNumOperands() != 1 ||
+        yield->getOperand(0) != combiner.getResult(0))
+      return nullptr;
+    return &combiner;
+  }
+
+  /// Returns true if `init` is produced by a `linalg.fill` whose fill value is
+  /// a constant equal to the neutral element of `combiner`.
+  static bool isNeutralFill(Value init, Operation *combiner) {
+    auto fill = init.getDefiningOp<FillOp>();
+    if (!fill)
+      return false;
+    std::optional<TypedAttr> neutral = arith::getNeutralElement(combiner);
+    if (!neutral)
+      return false;
+    Attribute fillValue;
+    if (!matchPattern(fill.getDpsInputs().front(), m_Constant(&fillValue)))
+      return false;
+    if (fillValue == *neutral)
+      return true;
+    // Accept either signed zero so the check does not depend on which zero
+    // getNeutralElement reports as the additive identity.
+    auto fillFloat = dyn_cast<FloatAttr>(fillValue);
+    auto neutralFloat = dyn_cast<FloatAttr>(*neutral);
+    return fillFloat && neutralFloat && fillFloat.getValue().isZero() &&
+           neutralFloat.getValue().isZero();
+  }
+
+  /// Maps the consumer's reduction dims (which index the producer's result)
+  /// back onto the producer's input, and returns the sorted union with the
+  /// producer's own reduction dims.
+  ///
+  /// Example: the producer reduces dims [0, 4] of a rank-5 input, leaving
+  /// original dims [1, 2, 3]. Consumer dims [0, 1] then map to [1, 2], and the
+  /// merged result is [0, 1, 2, 4].
+  static SmallVector<int64_t>
+  mergeConsecutiveReduceDims(ArrayRef<unsigned> prodDims,
+                             ArrayRef<unsigned> consDims, unsigned maxRank) {
+    BitVector availableMask(maxRank, true);
+    for (unsigned dim : prodDims)
+      availableMask[dim] = false;
+    // Original input dims that survive the producer, in result order.
+    SmallVector<unsigned> remainingDimIndex;
+    for (unsigned i = 0; i < maxRank; i++)
+      if (availableMask[i])
+        remainingDimIndex.push_back(i);
+    SmallVector<int64_t> newDims(prodDims.begin(), prodDims.end());
+    for (unsigned dim : consDims)
+      newDims.push_back(remainingDimIndex[dim]);
+    llvm::sort(newDims);
+    return newDims;
+  }
+};
+
+/// Greedily applies MergeConsecutiveReduceOp to the operation the pass runs
+/// on.
+struct LinalgMergeConsecutiveReduceOpsPass
+    : public impl::LinalgMergeConsecutiveReduceOpsPassBase<
+          LinalgMergeConsecutiveReduceOpsPass> {
+  using impl::LinalgMergeConsecutiveReduceOpsPassBase<
+      LinalgMergeConsecutiveReduceOpsPass>::
+      LinalgMergeConsecutiveReduceOpsPassBase;
+
+  void runOnOperation() override {
+    Operation *op = getOperation();
+    RewritePatternSet patterns(op->getContext());
+    patterns.add<MergeConsecutiveReduceOp>(op->getContext());
+
+    if (failed(applyPatternsGreedily(op, std::move(patterns))))
+      signalPassFailure();
+  }
+};
+} // namespace
diff --git a/mlir/test/Dialect/Linalg/merge-consecutive-reduce-ops.mlir b/mlir/test/Dialect/Linalg/merge-consecutive-reduce-ops.mlir
new file mode 100644
index 0000000000000..7a8050d112846
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/merge-consecutive-reduce-ops.mlir
@@ -0,0 +1,45 @@
+// RUN: mlir-opt %s -linalg-merge-consecutive-reduce-ops -split-input-file | FileCheck %s
+
+// Reducing dims [0, 2] of a 2x3x4x5 input leaves a 3x5 tensor made of original
+// dims 1 and 3. The second reduce then reduces both of those, so the merged op
+// reduces every original dimension ([0, 1, 2, 3]) straight into the final init.
+// CHECK-LABEL: func.func @merge_consecutive_reduce(
+// CHECK-SAME:    %[[INPUT:[a-zA-Z0-9_]+]]: tensor<2x3x4x5xf32>
+// CHECK-SAME:    %[[INIT:[a-zA-Z0-9_]+]]: tensor<f32>
+// CHECK:         %[[REDUCED:.+]] = linalg.reduce { arith.addf } ins(%[[INPUT]] : tensor<2x3x4x5xf32>) outs(%[[INIT]] : tensor<f32>) dimensions = [0, 1, 2, 3]
+// CHECK-NEXT:    return %[[REDUCED]] : tensor<f32>
+func.func @merge_consecutive_reduce(
+    %input: tensor<2x3x4x5xf32>, %init: tensor<f32>) -> tensor<f32> {
+  %cst = arith.constant 0.000000e+00 : f32
+  %empty = tensor.empty() : tensor<3x5xf32>
+  %fill = linalg.fill ins(%cst : f32) outs(%empty : tensor<3x5xf32>) -> tensor<3x5xf32>
+  %first_reduce = linalg.reduce { arith.addf }
+      ins(%input : tensor<2x3x4x5xf32>)
+      outs(%fill : tensor<3x5xf32>)
+      dimensions = [0, 2]
+  %second_reduce = linalg.reduce { arith.addf }
+      ins(%first_reduce : tensor<3x5xf32>)
+      outs(%init : tensor<f32>)
+      dimensions = [0, 1]
+  return %second_reduce : tensor<f32>
+}
+
+// -----
+
+// Reducing dims [0, 4] of a 2x3x4x5x6 input leaves original dims 1, 2 and 3
+// (a 3x4x5 tensor). The second reduce's dims [0, 1] therefore refer to
+// original dims 1 and 2, and the merged op reduces dimensions [0, 1, 2, 4].
+// CHECK-LABEL: func.func @merge_consecutive_reduce_with_projected_dims(
+// CHECK-SAME:    %[[INPUT:[a-zA-Z0-9_]+]]: tensor<2x3x4x5x6xf32>
+// CHECK-SAME:    %[[INIT:[a-zA-Z0-9_]+]]: tensor<5xf32>
+// CHECK:         %[[REDUCED:.+]] = linalg.reduce { arith.addf } ins(%[[INPUT]] : tensor<2x3x4x5x6xf32>) outs(%[[INIT]] : tensor<5xf32>) dimensions = [0, 1, 2, 4]
+// CHECK-NEXT:    return %[[REDUCED]] : tensor<5xf32>
+func.func @merge_consecutive_reduce_with_projected_dims(
+    %input: tensor<2x3x4x5x6xf32>, %init: tensor<5xf32>) -> tensor<5xf32> {
+  %cst = arith.constant 0.000000e+00 : f32
+  %empty = tensor.empty() : tensor<3x4x5xf32>
+  %fill = linalg.fill ins(%cst : f32) outs(%empty : tensor<3x4x5xf32>) -> tensor<3x4x5xf32>
+  %first_reduce = linalg.reduce { arith.addf }
+      ins(%input : tensor<2x3x4x5x6xf32>)
+      outs(%fill : tensor<3x4x5xf32>)
+      dimensions = [0, 4]
+  %second_reduce = linalg.reduce { arith.addf }
+      ins(%first_reduce : tensor<3x4x5xf32>)
+      outs(%init : tensor<5xf32>)
+      dimensions = [0, 1]
+  return %second_reduce : tensor<5xf32>
+}
+
+// -----
+
+// The first reduce's result has a second user (it is also returned), so the
+// pattern must leave both reductions in place.
+// CHECK-LABEL: func.func @no_merge_multiple_uses(
+// CHECK:         linalg.reduce { arith.addf }
+// CHECK:         linalg.reduce { arith.addf }
+func.func @no_merge_multiple_uses(
+    %input: tensor<2x3xf32>, %init: tensor<f32>) -> (tensor<f32>, tensor<3xf32>) {
+  %cst = arith.constant 0.000000e+00 : f32
+  %empty = tensor.empty() : tensor<3xf32>
+  %fill = linalg.fill ins(%cst : f32) outs(%empty : tensor<3xf32>) -> tensor<3xf32>
+  %first_reduce = linalg.reduce { arith.addf }
+      ins(%input : tensor<2x3xf32>)
+      outs(%fill : tensor<3xf32>)
+      dimensions = [0]
+  %second_reduce = linalg.reduce { arith.addf }
+      ins(%first_reduce : tensor<3xf32>)
+      outs(%init : tensor<f32>)
+      dimensions = [0]
+  return %second_reduce, %first_reduce : tensor<f32>, tensor<3xf32>
+}
+
+// -----
+
+// The two reductions use different combiners, so their bodies are not
+// equivalent and they must not be merged.
+// CHECK-LABEL: func.func @no_merge_different_combiners(
+// CHECK:         linalg.reduce { arith.addf }
+// CHECK:         linalg.reduce { arith.mulf }
+func.func @no_merge_different_combiners(
+    %input: tensor<2x3xf32>, %init: tensor<f32>) -> tensor<f32> {
+  %cst = arith.constant 0.000000e+00 : f32
+  %empty = tensor.empty() : tensor<3xf32>
+  %fill = linalg.fill ins(%cst : f32) outs(%empty : tensor<3xf32>) -> tensor<3xf32>
+  %first_reduce = linalg.reduce { arith.addf }
+      ins(%input : tensor<2x3xf32>)
+      outs(%fill : tensor<3xf32>)
+      dimensions = [0]
+  %second_reduce = linalg.reduce { arith.mulf }
+      ins(%first_reduce : tensor<3xf32>)
+      outs(%init : tensor<f32>)
+      dimensions = [0]
+  return %second_reduce : tensor<f32>
+}



More information about the Mlir-commits mailing list