[Mlir-commits] [mlir] [mlir][linalg] Add merge consecutive linalg::reduceOp canonicalization (PR #195048)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Apr 30 03:13:40 PDT 2026


llvmorg-github-actions[bot] wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

Author: Hocky Yudhiono (hockyy)

<details>
<summary>Changes</summary>

Introduce a canonicalization pattern that merges two consecutive linalg.reduce ops with equivalent combiner regions into a single linalg.reduce over the union of their reduction dimensions.

---
Full diff: https://github.com/llvm/llvm-project/pull/195048.diff


3 Files Affected:

- (modified) mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td (+1) 
- (modified) mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp (+78) 
- (modified) mlir/test/Dialect/Linalg/canonicalize.mlir (+46) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
index 5998f736ced34..4503648deb994 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
@@ -387,6 +387,7 @@ def ReduceOp : LinalgStructuredBase_Op<"reduce", [
     MutableOperandRange getDpsInitsMutable() { return getInitsMutable(); }
   }];
 
+  let hasCanonicalizer = 1;
   let hasCustomAssemblyFormat = 1;
   let hasVerifier = 1;
 }
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 27988a451173c..9b0e5c8518f33 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -1803,6 +1803,136 @@ Speculation::Speculatability ReduceOp::getSpeculatability() {
   return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
 }
 
+namespace {
+
+/// Merges two back-to-back `linalg.reduce` ops that use equivalent combiner
+/// regions into a single `linalg.reduce` over the union of their reduction
+/// dimensions:
+///
+///   %p = linalg.reduce ins(%x) outs(%fill) dimensions = [...]
+///   %c = linalg.reduce ins(%p) outs(%init) dimensions = [...]
+/// ==>
+///   %c = linalg.reduce ins(%x) outs(%init) dimensions = <merged>
+///
+/// The producer's init operand is dropped by the rewrite, which is only sound
+/// when that init contributes nothing to the result. The pattern therefore
+/// requires it to be a `linalg.fill` of the neutral element of the (single)
+/// combiner op, e.g. 0.0 for `arith.addf`; any other value would have been
+/// folded in once per element of the intermediate tensor. Like any merge of
+/// reductions, the rewrite reassociates the combiner.
+struct MergeConsecutiveReduceOp : OpRewritePattern<linalg::ReduceOp> {
+  using OpRewritePattern<linalg::ReduceOp>::OpRewritePattern;
+  LogicalResult matchAndRewrite(linalg::ReduceOp consumer,
+                                PatternRewriter &rewriter) const override {
+    if (consumer.getNumDpsInputs() != 1)
+      return rewriter.notifyMatchFailure(
+          consumer, "Only supports second reduce op with one input");
+    Value input = consumer.getDpsInputs().front();
+    if (!input.hasOneUse())
+      return rewriter.notifyMatchFailure(
+          consumer, "Does not support producer result with multiple users");
+    auto producer = input.getDefiningOp<linalg::ReduceOp>();
+    if (!producer)
+      return rewriter.notifyMatchFailure(consumer,
+                                         "Does not find consecutive reduces");
+    // A variadic producer has further results that may still be live, so it
+    // could not be erased below, and its extra inputs would not line up with
+    // the consumer's single init.
+    if (producer.getNumDpsInputs() != 1 || producer->getNumResults() != 1)
+      return rewriter.notifyMatchFailure(
+          consumer, "Only supports first reduce op with one input and result");
+    if (consumer->getBlock() != producer->getBlock())
+      return rewriter.notifyMatchFailure(
+          consumer, "Does not support reduce in different blocks");
+    if (!OperationEquivalence::isRegionEquivalentTo(
+            &consumer.getRegion(), &producer.getRegion(),
+            OperationEquivalence::Flags::IgnoreLocations))
+      return rewriter.notifyMatchFailure(
+          consumer, "Reduce operation regions is not equal");
+    if (!hasNeutralInit(producer))
+      return rewriter.notifyMatchFailure(
+          consumer, "First reduce init is not the combiner's neutral element");
+
+    SmallVector<unsigned> prodDims, consDims;
+    producer.getReductionDims(prodDims);
+    consumer.getReductionDims(consDims);
+    unsigned inputRank =
+        cast<ShapedType>(producer.getDpsInputs()[0].getType()).getRank();
+    SmallVector<int64_t> dims =
+        mergeConsecutiveReduceDims(prodDims, consDims, inputRank);
+
+    rewriter.setInsertionPointAfter(consumer);
+    auto newReduce = linalg::ReduceOp::create(
+        rewriter, consumer->getLoc(), consumer->getResultTypes(),
+        producer.getInputs(), consumer.getInits(), dims);
+    // Clone through the rewriter so that listeners see the new ops.
+    Region &newRegion = newReduce.getRegion();
+    rewriter.cloneRegionBefore(consumer.getRegion(), newRegion,
+                               newRegion.end());
+
+    rewriter.replaceOp(consumer, newReduce);
+    rewriter.eraseOp(producer);
+    return success();
+  }
+
+  /// Returns true if the init of `reduce` is a `linalg.fill` of a constant
+  /// equal to the neutral element of the single combiner op in its body.
+  /// Float constants are compared by value, so +0.0 and -0.0 are both
+  /// accepted for `arith.addf`, whichever one `arith::getNeutralElement`
+  /// reports.
+  static bool hasNeutralInit(linalg::ReduceOp reduce) {
+    Block &body = reduce.getRegion().front();
+    // Only a body made of one combiner op plus the terminator is recognized.
+    if (!llvm::hasSingleElement(body.without_terminator()))
+      return false;
+    std::optional<TypedAttr> neutral = arith::getNeutralElement(&body.front());
+    if (!neutral)
+      return false;
+    auto fill = reduce.getInits().front().getDefiningOp<linalg::FillOp>();
+    if (!fill)
+      return false;
+    Attribute fillValue;
+    if (!matchPattern(fill.getInputs().front(), m_Constant(&fillValue)))
+      return false;
+    auto fillFloat = dyn_cast<FloatAttr>(fillValue);
+    auto neutralFloat = dyn_cast<FloatAttr>(*neutral);
+    if (fillFloat && neutralFloat)
+      return fillFloat.getValue().compare(neutralFloat.getValue()) ==
+             APFloat::cmpEqual;
+    return fillValue == *neutral;
+  }
+
+  /// Maps the consumer's reduction dims back onto the producer's input and
+  /// returns them merged with the producer's own reduction dims, sorted.
+  /// The consumer's dims index the producer's result, whose dims are the
+  /// input dims the producer did not reduce, in order. E.g. with rank 5,
+  /// prodDims = [0, 4] keeps [1, 2, 3]; consDims = [0, 1] maps to [1, 2];
+  /// the merged result is [0, 1, 2, 4].
+  static SmallVector<int64_t>
+  mergeConsecutiveReduceDims(ArrayRef<unsigned> prodDims,
+                             ArrayRef<unsigned> consDims, unsigned maxRank) {
+    BitVector reducedByProducer(maxRank, false);
+    for (unsigned dim : prodDims)
+      reducedByProducer.set(dim);
+    SmallVector<unsigned> keptDims;
+    for (unsigned i = 0; i < maxRank; ++i)
+      if (!reducedByProducer.test(i))
+        keptDims.push_back(i);
+    SmallVector<int64_t> merged(prodDims.begin(), prodDims.end());
+    for (unsigned dim : consDims)
+      merged.push_back(keptDims[dim]);
+    llvm::sort(merged);
+    return merged;
+  }
+};
+
+} // namespace
+
+void ReduceOp::getCanonicalizationPatterns(RewritePatternSet &results,
+                                           MLIRContext *context) {
+  results.add<MergeConsecutiveReduceOp>(context);
+}
+
 static ParseResult parseDenseI64ArrayAttr(OpAsmParser &parser,
                                           NamedAttrList &attributes,
                                           StringRef attributeName) {
diff --git a/mlir/test/Dialect/Linalg/canonicalize.mlir b/mlir/test/Dialect/Linalg/canonicalize.mlir
index 019b7433b2777..1d0ed5bd7c6df 100644
--- a/mlir/test/Dialect/Linalg/canonicalize.mlir
+++ b/mlir/test/Dialect/Linalg/canonicalize.mlir
@@ -2258,3 +2258,49 @@ func.func @no_fold_pack_cast_inner_tile_inlined_mismatch(%arg0: tensor<8x3xi32>,
     into %dest : tensor<?x?xi32> -> tensor<?x3x?x1xi32>
   return %pack : tensor<?x3x?x1xi32>
 }
+
+// -----
+
+// CHECK-LABEL: func.func @fold_consecutive_reduce(
+// CHECK-SAME:    %[[INPUT:[a-zA-Z0-9_]+]]: tensor<2x3x4x5xf32>
+// CHECK-SAME:    %[[INIT:[a-zA-Z0-9_]+]]: tensor<f32>
+// CHECK:         %[[REDUCED:.+]] = linalg.reduce { arith.addf } ins(%[[INPUT]] : tensor<2x3x4x5xf32>) outs(%[[INIT]] : tensor<f32>) dimensions = [0, 1, 2, 3]
+// CHECK-NEXT:    return %[[REDUCED]] : tensor<f32>
+func.func @fold_consecutive_reduce(
+    %input: tensor<2x3x4x5xf32>, %init: tensor<f32>) -> tensor<f32> {
+  %zero = arith.constant 0.000000e+00 : f32  // neutral element of arith.addf
+  %acc = tensor.empty() : tensor<3x5xf32>
+  %acc_zero = linalg.fill ins(%zero : f32) outs(%acc : tensor<3x5xf32>) -> tensor<3x5xf32>
+  %partial = linalg.reduce { arith.addf }
+      ins(%input : tensor<2x3x4x5xf32>)
+      outs(%acc_zero : tensor<3x5xf32>)
+      dimensions = [0, 2]  // keeps input dims 1 and 3
+  %total = linalg.reduce { arith.addf }
+      ins(%partial : tensor<3x5xf32>)
+      outs(%init : tensor<f32>)
+      dimensions = [0, 1]  // input dims 1 and 3, merged with [0, 2]
+  return %total : tensor<f32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @fold_consecutive_reduce_with_projected_dims(
+// CHECK-SAME:    %[[INPUT:[a-zA-Z0-9_]+]]: tensor<2x3x4x5x6xf32>
+// CHECK-SAME:    %[[INIT:[a-zA-Z0-9_]+]]: tensor<5xf32>
+// CHECK:         %[[REDUCED:.+]] = linalg.reduce { arith.addf } ins(%[[INPUT]] : tensor<2x3x4x5x6xf32>) outs(%[[INIT]] : tensor<5xf32>) dimensions = [0, 1, 2, 4]
+// CHECK-NEXT:    return %[[REDUCED]] : tensor<5xf32>
+func.func @fold_consecutive_reduce_with_projected_dims(
+    %input: tensor<2x3x4x5x6xf32>, %init: tensor<5xf32>) -> tensor<5xf32> {
+  %zero = arith.constant 0.000000e+00 : f32  // neutral element of arith.addf
+  %acc = tensor.empty() : tensor<3x4x5xf32>
+  %acc_zero = linalg.fill ins(%zero : f32) outs(%acc : tensor<3x4x5xf32>) -> tensor<3x4x5xf32>
+  %partial = linalg.reduce { arith.addf }
+      ins(%input : tensor<2x3x4x5x6xf32>)
+      outs(%acc_zero : tensor<3x4x5xf32>)
+      dimensions = [0, 4]  // keeps input dims 1, 2 and 3
+  %total = linalg.reduce { arith.addf }
+      ins(%partial : tensor<3x4x5xf32>)
+      outs(%init : tensor<5xf32>)
+      dimensions = [0, 1]  // input dims 1 and 2; input dim 3 (size 5) survives
+  return %total : tensor<5xf32>
+}

``````````

</details>


https://github.com/llvm/llvm-project/pull/195048


More information about the Mlir-commits mailing list