[Mlir-commits] [mlir] [mlir][linalg] Add merge consecutive linalg::reduceOp canonicalization (PR #195048)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Thu Apr 30 03:13:40 PDT 2026
llvmorg-github-actions[bot] wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir
Author: Hocky Yudhiono (hockyy)
<details>
<summary>Changes</summary>
Introduce a canonicalization pattern that merges two consecutive `linalg.reduce` ops with equivalent combiner regions into a single `linalg.reduce` op.
---
Full diff: https://github.com/llvm/llvm-project/pull/195048.diff
3 Files Affected:
- (modified) mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td (+1)
- (modified) mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp (+78)
- (modified) mlir/test/Dialect/Linalg/canonicalize.mlir (+46)
``````````diff
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
index 5998f736ced34..4503648deb994 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
@@ -387,6 +387,7 @@ def ReduceOp : LinalgStructuredBase_Op<"reduce", [
MutableOperandRange getDpsInitsMutable() { return getInitsMutable(); }
}];
+ let hasCanonicalizer = 1;
let hasCustomAssemblyFormat = 1;
let hasVerifier = 1;
}
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 27988a451173c..9b0e5c8518f33 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -1803,6 +1803,84 @@ Speculation::Speculatability ReduceOp::getSpeculatability() {
return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
}
+namespace {
+
+struct MergeConsecutiveReduceOp : OpRewritePattern<linalg::ReduceOp> {
+ using OpRewritePattern<linalg::ReduceOp>::OpRewritePattern;
+ LogicalResult matchAndRewrite(linalg::ReduceOp consumer,
+ PatternRewriter &rewriter) const override {
+ if (consumer.getNumDpsInputs() != 1) {
+ return rewriter.notifyMatchFailure(
+ consumer, "Only supports second reduce op with one input");
+ }
+ Value input = consumer.getDpsInputs().front();
+ if (!input.hasOneUse()) {
+ return rewriter.notifyMatchFailure(
+ consumer, "Does not support producer result with multiple users");
+ }
+ auto producer = input.getDefiningOp<linalg::ReduceOp>();
+ if (!producer) {
+ return rewriter.notifyMatchFailure(consumer,
+ "Does not find consecutive reduces");
+ }
+ if (consumer.getOperation()->getBlock() !=
+ producer.getOperation()->getBlock()) {
+ return rewriter.notifyMatchFailure(
+ consumer, "Does not support reduce in different blocks");
+ }
+ if (!OperationEquivalence::isRegionEquivalentTo(
+ &consumer.getRegion(), &producer.getRegion(),
+ OperationEquivalence::Flags::IgnoreLocations)) {
+ return rewriter.notifyMatchFailure(
+ consumer, "Reduce operation regions is not equal");
+ }
+ SmallVector<unsigned> prodDims, consDims;
+ producer.getReductionDims(prodDims);
+ consumer.getReductionDims(consDims);
+ auto maxRank =
+ cast<ShapedType>(producer.getDpsInputs()[0].getType()).getRank();
+
+ auto dims = mergeConsecutiveReduceDims(prodDims, consDims, maxRank);
+ rewriter.setInsertionPointAfter(consumer);
+ auto newReduce = linalg::ReduceOp::create(
+ rewriter, consumer->getLoc(), TypeRange(consumer->getResults()),
+ producer.getInputs(), consumer.getInits(), dims);
+ Region &newRegion = newReduce.getRegion();
+ IRMapping mapping;
+ consumer.getRegion().cloneInto(&newRegion, newRegion.begin(), mapping);
+
+ rewriter.replaceOp(consumer, newReduce);
+ rewriter.eraseOp(producer);
+ return success();
+ }
+
+ /// Merge two reduce dims of consecutive reduce ops, return the merged dims
+ /// that work on the origin reduce input.
+ SmallVector<int64_t> mergeConsecutiveReduceDims(ArrayRef<unsigned> prodDims,
+ ArrayRef<unsigned> consDims,
+ unsigned maxRank) const {
+ BitVector availableMask(maxRank, true);
+ for (unsigned dim : prodDims)
+ availableMask[dim] = false;
+ SmallVector<unsigned> remainingDimIndex;
+ for (unsigned i = 0; i < maxRank; i++)
+ if (availableMask[i])
+ remainingDimIndex.push_back(i);
+ SmallVector<int64_t> newDims(prodDims);
+ for (unsigned dim : consDims)
+ newDims.push_back(remainingDimIndex[dim]);
+ llvm::sort(newDims.begin(), newDims.end());
+ return newDims;
+ }
+};
+
+} // namespace
+
+void ReduceOp::getCanonicalizationPatterns(RewritePatternSet &results,
+ MLIRContext *context) {
+ results.add<MergeConsecutiveReduceOp>(context);
+}
+
static ParseResult parseDenseI64ArrayAttr(OpAsmParser &parser,
NamedAttrList &attributes,
StringRef attributeName) {
diff --git a/mlir/test/Dialect/Linalg/canonicalize.mlir b/mlir/test/Dialect/Linalg/canonicalize.mlir
index 019b7433b2777..1d0ed5bd7c6df 100644
--- a/mlir/test/Dialect/Linalg/canonicalize.mlir
+++ b/mlir/test/Dialect/Linalg/canonicalize.mlir
@@ -2258,3 +2258,49 @@ func.func @no_fold_pack_cast_inner_tile_inlined_mismatch(%arg0: tensor<8x3xi32>,
into %dest : tensor<?x?xi32> -> tensor<?x3x?x1xi32>
return %pack : tensor<?x3x?x1xi32>
}
+
+// -----
+
+// CHECK-LABEL: func.func @fold_consecutive_reduce(
+// CHECK-SAME: %[[INPUT:[a-zA-Z0-9_]+]]: tensor<2x3x4x5xf32>
+// CHECK-SAME: %[[INIT:[a-zA-Z0-9_]+]]: tensor<f32>
+// CHECK: %[[REDUCED:.+]] = linalg.reduce { arith.addf } ins(%[[INPUT]] : tensor<2x3x4x5xf32>) outs(%[[INIT]] : tensor<f32>) dimensions = [0, 1, 2, 3]
+// CHECK-NEXT: return %[[REDUCED]] : tensor<f32>
+func.func @fold_consecutive_reduce(
+ %input: tensor<2x3x4x5xf32>, %init: tensor<f32>) -> tensor<f32> {
+ %cst = arith.constant 0.000000e+00 : f32
+ %empty = tensor.empty() : tensor<3x5xf32>
+ %fill = linalg.fill ins(%cst : f32) outs(%empty : tensor<3x5xf32>) -> tensor<3x5xf32>
+ %first_reduce = linalg.reduce { arith.addf }
+ ins(%input : tensor<2x3x4x5xf32>)
+ outs(%fill : tensor<3x5xf32>)
+ dimensions = [0, 2]
+ %second_reduce = linalg.reduce { arith.addf }
+ ins(%first_reduce : tensor<3x5xf32>)
+ outs(%init : tensor<f32>)
+ dimensions = [0, 1]
+ return %second_reduce : tensor<f32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @fold_consecutive_reduce_with_projected_dims(
+// CHECK-SAME: %[[INPUT:[a-zA-Z0-9_]+]]: tensor<2x3x4x5x6xf32>
+// CHECK-SAME: %[[INIT:[a-zA-Z0-9_]+]]: tensor<5xf32>
+// CHECK: %[[REDUCED:.+]] = linalg.reduce { arith.addf } ins(%[[INPUT]] : tensor<2x3x4x5x6xf32>) outs(%[[INIT]] : tensor<5xf32>) dimensions = [0, 1, 2, 4]
+// CHECK-NEXT: return %[[REDUCED]] : tensor<5xf32>
+func.func @fold_consecutive_reduce_with_projected_dims(
+ %input: tensor<2x3x4x5x6xf32>, %init: tensor<5xf32>) -> tensor<5xf32> {
+ %cst = arith.constant 0.000000e+00 : f32
+ %empty = tensor.empty() : tensor<3x4x5xf32>
+ %fill = linalg.fill ins(%cst : f32) outs(%empty : tensor<3x4x5xf32>) -> tensor<3x4x5xf32>
+ %first_reduce = linalg.reduce { arith.addf }
+ ins(%input : tensor<2x3x4x5x6xf32>)
+ outs(%fill : tensor<3x4x5xf32>)
+ dimensions = [0, 4]
+ %second_reduce = linalg.reduce { arith.addf }
+ ins(%first_reduce : tensor<3x4x5xf32>)
+ outs(%init : tensor<5xf32>)
+ dimensions = [0, 1]
+ return %second_reduce : tensor<5xf32>
+}
``````````
</details>
https://github.com/llvm/llvm-project/pull/195048
More information about the Mlir-commits mailing list