[Mlir-commits] [mlir] Reland "[mlir][reducer] Add eraseRedundantBlocksInRegion to reduction-tree pass" (PR #191961)
lonely eagle
llvmlistbot at llvm.org
Thu Apr 30 00:50:47 PDT 2026
https://github.com/linuxlonelyeagle updated https://github.com/llvm/llvm-project/pull/191961
>From 86b363cf10c2152b3541eb0eaa9099d997a48c74 Mon Sep 17 00:00:00 2001
From: linuxlonelyeagle <2020382038 at qq.com>
Date: Fri, 20 Mar 2026 16:48:40 +0000
Subject: [PATCH 1/4] [mlir][reducer] Add eraseRedundantBlocksInRegion to reduction-tree pass (squash of 4 commits)
add eraseRedundantBlocksInRegion.
---
.../Dialect/ControlFlow/IR/ControlFlowOps.td | 10 +-
.../mlir/Interfaces/ControlFlowInterfaces.td | 9 ++
mlir/include/mlir/Reducer/ReductionNode.h | 3 +
.../Dialect/ControlFlow/IR/ControlFlowOps.cpp | 24 ++++
mlir/lib/Reducer/ReductionNode.cpp | 10 ++
mlir/lib/Reducer/ReductionTreePass.cpp | 115 ++++++++++++++++++
mlir/test/mlir-reduce/reduction-tree.mlir | 66 ++++++++++
7 files changed, 233 insertions(+), 4 deletions(-)
diff --git a/mlir/include/mlir/Dialect/ControlFlow/IR/ControlFlowOps.td b/mlir/include/mlir/Dialect/ControlFlow/IR/ControlFlowOps.td
index a441fd82546e3..ddea3a7eae590 100644
--- a/mlir/include/mlir/Dialect/ControlFlow/IR/ControlFlowOps.td
+++ b/mlir/include/mlir/Dialect/ControlFlow/IR/ControlFlowOps.td
@@ -65,7 +65,8 @@ def AssertOp : CF_Op<"assert",
//===----------------------------------------------------------------------===//
def BranchOp : CF_Op<"br", [
- DeclareOpInterfaceMethods<BranchOpInterface, ["getSuccessorForOperands"]>,
+ DeclareOpInterfaceMethods<BranchOpInterface,
+ ["getSuccessorForOperands", "getSuccessorForwardOperands"]>,
Pure, Terminator
]> {
let summary = "Branch operation";
@@ -114,8 +115,8 @@ def BranchOp : CF_Op<"br", [
def CondBranchOp
: CF_Op<"cond_br", [AttrSizedOperandSegments,
- DeclareOpInterfaceMethods<
- BranchOpInterface, ["getSuccessorForOperands"]>,
+ DeclareOpInterfaceMethods<BranchOpInterface,
+ ["getSuccessorForOperands", "getSuccessorForwardOperands"]>,
WeightedBranchOpInterface, Pure, Terminator]> {
let summary = "Conditional branch operation";
let description = [{
@@ -241,7 +242,8 @@ def CondBranchOp
def SwitchOp : CF_Op<"switch",
[AttrSizedOperandSegments,
- DeclareOpInterfaceMethods<BranchOpInterface, ["getSuccessorForOperands"]>,
+ DeclareOpInterfaceMethods<BranchOpInterface,
+ ["getSuccessorForOperands", "getSuccessorForwardOperands"]>,
Pure, Terminator]> {
let summary = "Switch operation";
let description = [{
diff --git a/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td b/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td
index 06fa724e05fab..d32be0c63acc7 100644
--- a/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td
+++ b/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td
@@ -98,6 +98,15 @@ def BranchOpInterface : OpInterface<"BranchOpInterface"> {
(ins "::mlir::Type":$lhs, "::mlir::Type":$rhs), [{}],
[{ return lhs == rhs; }]
>,
+ InterfaceMethod<[{
+ This method is called to returns the operands of this operation that
+ are passed to the specified successor's block arguments. If the successor
+ is not valid for this operation, or no operands are forwarded, an empty
+ ValueRange is returned.
+ }],
+ "ValueRange", "getSuccessorForwardOperands",
+ (ins "Block *":$successor), [{}],[{ return {};}]
+ >,
];
let verify = [{
diff --git a/mlir/include/mlir/Reducer/ReductionNode.h b/mlir/include/mlir/Reducer/ReductionNode.h
index 6ca4e13d159ac..125a7c6f6f5e7 100644
--- a/mlir/include/mlir/Reducer/ReductionNode.h
+++ b/mlir/include/mlir/Reducer/ReductionNode.h
@@ -90,6 +90,9 @@ class ReductionNode {
/// corresponding region.
LogicalResult initialize(ModuleOp parentModule, Region &parentRegion);
+ /// Clones `parentModule` into this node using `mapper`, which afterwards
+ /// holds the original-to-clone mapping, and points this node's region at the
+ /// clone of `parentRegion`. Callers can use `mapper` to find the cloned
+ /// counterparts of ops and blocks of the original module.
+ LogicalResult initialize(ModuleOp parentModule, Region &parentRegion,
+ IRMapping &mapper);
+
private:
/// A custom BFS iterator. The difference between
/// llvm/ADT/BreadthFirstIterator.h is the graph we're exploring is dynamic.
diff --git a/mlir/lib/Dialect/ControlFlow/IR/ControlFlowOps.cpp b/mlir/lib/Dialect/ControlFlow/IR/ControlFlowOps.cpp
index 435c37bc95aac..f6eb0f05911b8 100644
--- a/mlir/lib/Dialect/ControlFlow/IR/ControlFlowOps.cpp
+++ b/mlir/lib/Dialect/ControlFlow/IR/ControlFlowOps.cpp
@@ -296,6 +296,12 @@ Block *BranchOp::getSuccessorForOperands(ArrayRef<Attribute>) {
return getDest();
}
+ValueRange BranchOp::getSuccessorForwardOperands(Block *successor) {
+ if (successor == getDest())
+ return getDestOperands();
+ return {};
+}
+
//===----------------------------------------------------------------------===//
// CondBranchOp
//===----------------------------------------------------------------------===//
@@ -583,6 +589,14 @@ Block *CondBranchOp::getSuccessorForOperands(ArrayRef<Attribute> operands) {
return nullptr;
}
+ValueRange CondBranchOp::getSuccessorForwardOperands(Block *successor) {
+ if (successor == getTrueDest())
+ return getTrueOperands();
+ else if (successor == getFalseDest())
+ return getFalseOperands();
+ return {};
+}
+
//===----------------------------------------------------------------------===//
// SwitchOp
//===----------------------------------------------------------------------===//
@@ -1034,6 +1048,16 @@ void SwitchOp::getCanonicalizationPatterns(RewritePatternSet &results,
.add<SimplifyUniformBlockArguments>(context);
}
+ValueRange SwitchOp::getSuccessorForwardOperands(Block *successor) {
+ if (successor == getDefaultDestination())
+ return getDefaultOperands();
+ SuccessorRange caseDests = getCaseDestinations();
+ auto it = llvm::find(caseDests, successor);
+ if (it == caseDests.end())
+ return {};
+ return getCaseOperands(std::distance(caseDests.begin(), it));
+}
+
//===----------------------------------------------------------------------===//
// TableGen'd op method definitions
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Reducer/ReductionNode.cpp b/mlir/lib/Reducer/ReductionNode.cpp
index 11aeaf77b4642..897aae0becf33 100644
--- a/mlir/lib/Reducer/ReductionNode.cpp
+++ b/mlir/lib/Reducer/ReductionNode.cpp
@@ -45,6 +45,16 @@ LogicalResult ReductionNode::initialize(ModuleOp parentModule,
return success();
}
+/// Clones `parentModule` into this node, recording the original-to-clone
+/// mapping in `mapper`, and points this node's region at the clone of
+/// `targetRegion`. The cloned region is located through its first block, so
+/// an empty `targetRegion` cannot be located and yields failure (checked
+/// before cloning to avoid the wasted copy).
+LogicalResult ReductionNode::initialize(ModuleOp parentModule,
+ Region &targetRegion,
+ IRMapping &mapper) {
+ if (targetRegion.empty())
+ return failure();
+ module = cast<ModuleOp>(parentModule->clone(mapper));
+ // Use the first block of targetRegion to locate the cloned region.
+ Block *block = mapper.lookup(&targetRegion.front());
+ region = block->getParent();
+ return success();
+}
+
/// If we haven't explored any variants from this node, we will create N
/// variants, N is the length of `ranges` if N > 1. Otherwise, we will split the
/// max element in `ranges` and create 2 new variants for each call.
diff --git a/mlir/lib/Reducer/ReductionTreePass.cpp b/mlir/lib/Reducer/ReductionTreePass.cpp
index 83497143d9669..8a18c65fdacca 100644
--- a/mlir/lib/Reducer/ReductionTreePass.cpp
+++ b/mlir/lib/Reducer/ReductionTreePass.cpp
@@ -14,7 +14,10 @@
//
//===----------------------------------------------------------------------===//
+#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
#include "mlir/IR/DialectInterface.h"
+#include "mlir/IR/IRMapping.h"
+#include "mlir/Interfaces/ControlFlowInterfaces.h"
#include "mlir/Reducer/Passes.h"
#include "mlir/Reducer/ReductionNode.h"
#include "mlir/Reducer/ReductionPatternInterface.h"
@@ -24,6 +27,9 @@
#include "llvm/ADT/ArrayRef.h"
#include "llvm/Support/Allocator.h"
+#include "llvm/Support/DebugLog.h"
+
+#define DEBUG_TYPE "reduction-tree"
namespace mlir {
#define GEN_PASS_DEF_REDUCTIONTREEPASS
@@ -184,6 +190,113 @@ static LogicalResult eraseAllOpsInRegion(ModuleOp module, Region &region,
return failure();
}
+// Returns the terminator of the first block in `region` that branches to more
+// than one successor (cf.cond_br, cf.switch, etc.), or nullptr if there is
+// none. Only terminators the caller can rewrite are returned: they must
+// implement BranchOpInterface (the caller casts to it), and every successor
+// must receive only forwarded operands, because the caller rebuilds the chosen
+// edge as a cf.br from the forwarded operands alone.
+static Operation *getBranchTerminatorInRegion(Region &region) {
+ for (Block &block : region.getBlocks()) {
+ if (block.getNumSuccessors() <= 1)
+ continue;
+ auto branchOp = dyn_cast<BranchOpInterface>(block.getTerminator());
+ if (!branchOp)
+ continue;
+ bool onlyForwarded = true;
+ for (unsigned i = 0, e = block.getNumSuccessors(); i < e; ++i)
+ if (branchOp.getSuccessorOperands(i).getProducedOperandCount() != 0)
+ onlyForwarded = false;
+ if (onlyForwarded)
+ return block.getTerminator();
+ }
+ return nullptr;
+}
+
+/// Reduces the control flow in a region by iteratively forcing branching
+/// terminators to point to a single successor. For each branching terminator,
+/// every successor is tried on a clone of the module, and the smallest
+/// candidate that stays "interesting" is replayed on the original region.
+/// Reduction stops at the first terminator for which no candidate stays
+/// interesting. Emits a warning (returning failure) when `module` is not
+/// interesting to begin with; returns success otherwise.
+static LogicalResult eraseRedundantBlocksInRegion(ModuleOp module,
+ Region &region,
+ const Tester &test) {
+ std::pair<Tester::Interestingness, size_t> initStatus =
+ test.isInteresting(module);
+
+ // While exploring the reduction tree, we always branch from an interesting
+ // node. Thus the root node must be interesting.
+ if (initStatus.first != Tester::Interestingness::True)
+ return module.emitWarning() << "uninterested module will not be reduced";
+ llvm::SpecificBumpPtrAllocator<ReductionNode> allocator;
+
+ // We set the simplification level to Aggressive to enable block merging.
+ GreedyRewriteConfig config;
+ config.setRegionSimplificationLevel(GreedySimplifyRegionLevel::Aggressive);
+ config.setUseTopDownTraversal(true);
+
+ // Populate canonicalization patterns for cf ops. When all targets of a
+ // 'cf.cond_br' or 'cf.switch' point to the same block, they will be
+ // canonicalized into a 'cf.br'.
+ auto context = region.getContext();
+ RewritePatternSet patterns(context);
+ cf::BranchOp::getCanonicalizationPatterns(patterns, context);
+ cf::CondBranchOp::getCanonicalizationPatterns(patterns, context);
+ cf::SwitchOp::getCanonicalizationPatterns(patterns, context);
+ FrozenRewritePatternSet fPatterns = std::move(patterns);
+
+ // Best candidate of the terminator being reduced. Once a reduction has been
+ // committed it stays non-null, which the final cleanup below relies on.
+ ReductionNode *smallestNode = nullptr;
+ mlir::OpBuilder b(context);
+ while (Operation *branchTerminator = getBranchTerminatorInRegion(region)) {
+ size_t numSuccessor = branchTerminator->getNumSuccessors();
+ // We allocate memory on the heap because the object will be assigned to
+ // 'smallestNode'.
+ ReductionNode *root = allocator.Allocate();
+ std::vector<ReductionNode::Range> ranges{
+ {0, std::distance(region.op_begin(), region.op_end())}};
+
+ // Iterate through each successor of the branching terminator to try
+ // reducing the control flow to a single-path execution.
+ int branchIdx = -1;
+ for (int i = 0, e = numSuccessor; i < e; ++i) {
+ new (root) ReductionNode(nullptr, ranges, allocator);
+ mlir::IRMapping mapper;
+ if (failed(root->initialize(module, region, mapper)))
+ llvm_unreachable("unexpected initialization failure");
+ Operation *tergetTerminator = mapper.lookup(branchTerminator);
+ Block *selectedBlock = tergetTerminator->getSuccessor(i);
+ auto branchOp = cast<BranchOpInterface>(tergetTerminator);
+ ValueRange selectedBlockOperands =
+ branchOp.getSuccessorForwardOperands(selectedBlock);
+ b.setInsertionPointAfter(tergetTerminator);
+ cf::BranchOp::create(b, tergetTerminator->getLoc(), selectedBlock,
+ selectedBlockOperands);
+ tergetTerminator->erase();
+
+ // Apply canonicalization patterns to collapse the now-redundant branches
+ (void)applyPatternsGreedily(root->getRegion().getParentOp(), fPatterns,
+ config);
+ root->update(test.isInteresting(root->getModule()));
+
+ // Track the smallest "interesting" candidate for this terminator. Only
+ // candidates of the same terminator are compared (branchIdx != -1); a
+ // node left over from an earlier terminator is not a meaningful bound.
+ if (root->isInteresting() == Tester::Interestingness::True &&
+ (branchIdx == -1 ||
+ root->getSize() < smallestNode->getSize())) {
+ smallestNode = root;
+ branchIdx = i;
+ }
+ }
+
+ // If an interesting reduced branch was found, commit the change to the
+ // original region and re-apply patterns for a final cleanup.
+ if (branchIdx != -1) {
+ Block *selectedBlock = branchTerminator->getSuccessor(branchIdx);
+ auto branchOp = cast<BranchOpInterface>(branchTerminator);
+ ValueRange selectedBlockOperands =
+ branchOp.getSuccessorForwardOperands(selectedBlock);
+ b.setInsertionPointAfter(branchTerminator);
+ cf::BranchOp::create(b, branchTerminator->getLoc(), selectedBlock,
+ selectedBlockOperands);
+ branchTerminator->erase();
+ (void)applyPatternsGreedily(region.getParentOp(), fPatterns, config);
+ }
+ // Nothing was committed, so the original region is unchanged and
+ // getBranchTerminatorInRegion would return this same terminator forever.
+ // Stop reducing control flow instead of looping.
+ if (branchIdx == -1)
+ break;
+ }
+
+ // If no reduction was committed (no branching terminator, or none could be
+ // reduced), the patterns may still merge linear blocks or erase dead ones.
+ // Canonicalization can drop ops the tester depends on (for example by
+ // erasing an unreachable block), so try it on a clone first and only replay
+ // it on the original region when the result is still interesting.
+ if (smallestNode == nullptr && !region.empty()) {
+ std::vector<ReductionNode::Range> ranges{
+ {0, std::distance(region.op_begin(), region.op_end())}};
+ ReductionNode *cleanupNode = allocator.Allocate();
+ new (cleanupNode) ReductionNode(nullptr, ranges, allocator);
+ mlir::IRMapping mapper;
+ if (succeeded(cleanupNode->initialize(module, region, mapper))) {
+ (void)applyPatternsGreedily(cleanupNode->getRegion().getParentOp(),
+ fPatterns, config);
+ cleanupNode->update(test.isInteresting(cleanupNode->getModule()));
+ if (cleanupNode->isInteresting() == Tester::Interestingness::True)
+ (void)applyPatternsGreedily(region.getParentOp(), fPatterns, config);
+ }
+ }
+ return success();
+}
+
template <typename IteratorType>
static LogicalResult findOptimal(ModuleOp module, Region ®ion,
const FrozenRewritePatternSet &patterns,
@@ -196,6 +309,8 @@ static LogicalResult findOptimal(ModuleOp module, Region &region,
if (succeeded(eraseAllOpsInRegion(module, region, test)))
return success();
+ (void)eraseRedundantBlocksInRegion(module, region, test);
+
// In the second phase, we don't apply any patterns so that we only select the
// range of operations to keep to the module stay interesting.
if (failed(findOptimal<IteratorType>(module, region, /*patterns=*/{}, test,
diff --git a/mlir/test/mlir-reduce/reduction-tree.mlir b/mlir/test/mlir-reduce/reduction-tree.mlir
index 2aee89741b42b..b693765fbed53 100644
--- a/mlir/test/mlir-reduce/reduction-tree.mlir
+++ b/mlir/test/mlir-reduce/reduction-tree.mlir
@@ -58,3 +58,69 @@ func.func @simple4(%arg0: i1, %arg1: memref<2xf32>, %arg2: memref<2xf32>) {
func.func @simple5() {
return
}
+
+// -----
+
+// CHECK-LABEL: func @br_reduction
+// CHECK-SAME: %[[ARG0:.*]]: i1,
+// CHECK-SAME: %[[ARG1:.*]]: memref<2xf32>,
+// CHECK-SAME: %[[ARG2:.*]]: memref<2xf32>) {
+func.func @br_reduction(%arg0: i1, %arg1: memref<2xf32>, %arg2: memref<2xf32>) {
+ // select ^bb1: it forwards %arg1 directly, so ^bb2 and its memref.alloc
+ // become dead and are erased, giving the smallest interesting candidate.
+ cf.cond_br %arg0, ^bb1, ^bb2
+^bb1:
+ cf.br ^bb3(%arg1 : memref<2xf32>)
+^bb2:
+ %0 = memref.alloc() : memref<2xf32>
+ cf.br ^bb3(%0 : memref<2xf32>)
+^bb3(%1: memref<2xf32>):
+ "test.op_crash"(%1, %arg2) : (memref<2xf32>, memref<2xf32>) -> ()
+ return
+}
+// CHECK-NEXT: "test.op_crash"(%[[ARG1]], %[[ARG2]])
+
+// -----
+
+// CHECK-LABEL: func @br_reduction_loop
+// CHECK-SAME: %[[ARG0:.*]]: i1,
+// CHECK-SAME: %[[ARG1:.*]]: memref<2xf32>,
+// CHECK-SAME: %[[ARG2:.*]]: memref<2xf32>) {
+func.func @br_reduction_loop(%arg0: i1, %arg1: memref<2xf32>, %arg2: memref<2xf32>) {
+ cf.cond_br %arg0, ^bb1, ^bb2
+^bb1:
+ cf.br ^bb3(%arg1 : memref<2xf32>)
+^bb2:
+ %0 = memref.alloc() : memref<2xf32>
+ cf.br ^bb3(%0 : memref<2xf32>)
+^bb3(%1: memref<2xf32>):
+ "test.op_crash"(%1, %arg2) : (memref<2xf32>, memref<2xf32>) -> ()
+ cf.cond_br %arg0, ^bb3(%1: memref<2xf32>), ^bb4
+^bb4:
+ return
+}
+// CHECK-NEXT: cf.br ^bb1(%[[ARG1]] : memref<2xf32>)
+// CHECK-NEXT: ^bb1(%[[VAL_0:.*]]: memref<2xf32>):
+// CHECK-NEXT: "test.op_crash"(%[[VAL_0]], %[[ARG2]])
+// CHECK-NEXT: cf.br ^bb1(%[[VAL_0]] : memref<2xf32>)
+
+// -----
+
+// CHECK-LABEL: func @switch_reduction
+// CHECK-SAME: %[[ARG0:.*]]: i32,
+// CHECK-SAME: %[[ARG1:.*]]: memref<2xf32>,
+// CHECK-SAME: %[[ARG2:.*]]: memref<2xf32>) {
+func.func @switch_reduction(%arg0: i32, %arg1: memref<2xf32>, %arg2: memref<2xf32>) {
+ cf.switch %arg0 : i32, [
+ default: ^bb3(%arg1 : memref<2xf32>),
+ 0: ^bb1,
+ 1: ^bb2
+ ]
+^bb1:
+ cf.br ^bb3(%arg1 : memref<2xf32>)
+^bb2:
+ %0 = memref.alloc() : memref<2xf32>
+ cf.br ^bb3(%0 : memref<2xf32>)
+^bb3(%1: memref<2xf32>):
+ "test.op_crash"(%1, %arg2) : (memref<2xf32>, memref<2xf32>) -> ()
+ return
+}
+// CHECK-NEXT: "test.op_crash"(%[[ARG1]], %[[ARG2]])
\ No newline at end of file
>From 422f61c09c3481e9ca0bf2944544a7c3947f00c9 Mon Sep 17 00:00:00 2001
From: linuxlonelyeagle <2020382038 at qq.com>
Date: Sun, 29 Mar 2026 15:47:49 +0000
Subject: [PATCH 2/4] fix undefined symbol and memory leak issues.
---
mlir/lib/Reducer/CMakeLists.txt | 1 +
mlir/lib/Reducer/ReductionTreePass.cpp | 7 +++----
mlir/test/mlir-reduce/reduction-tree.mlir | 9 ++++-----
3 files changed, 8 insertions(+), 9 deletions(-)
diff --git a/mlir/lib/Reducer/CMakeLists.txt b/mlir/lib/Reducer/CMakeLists.txt
index 68864e373c993..b18a4bca04fcb 100644
--- a/mlir/lib/Reducer/CMakeLists.txt
+++ b/mlir/lib/Reducer/CMakeLists.txt
@@ -9,6 +9,7 @@ add_mlir_library(MLIRReduce
MLIRPass
MLIRRewrite
MLIRTransformUtils
+ MLIRControlFlowDialect
DEPENDS
MLIRReducerIncGen
diff --git a/mlir/lib/Reducer/ReductionTreePass.cpp b/mlir/lib/Reducer/ReductionTreePass.cpp
index 8a18c65fdacca..12358f7d71688 100644
--- a/mlir/lib/Reducer/ReductionTreePass.cpp
+++ b/mlir/lib/Reducer/ReductionTreePass.cpp
@@ -235,16 +235,15 @@ static LogicalResult eraseRedundantBlocksInRegion(ModuleOp module,
mlir::OpBuilder b(context);
while (Operation *branchTerminator = getBranchTerminatorInRegion(region)) {
size_t numSuccessor = branchTerminator->getNumSuccessors();
- // We allocate memory on the heap because the object will be assigned to
- // 'smallestNode'.
- ReductionNode *root = allocator.Allocate();
std::vector<ReductionNode::Range> ranges{
{0, std::distance(region.op_begin(), region.op_end())}};
-
// Iterate through each successor of the branching terminator to try
// reducing the control flow to a single-path execution.
int branchIdx = -1;
for (int i = 0, e = numSuccessor; i < e; ++i) {
+ // We allocate memory on the heap because the object will be assigned to
+ // 'smallestNode'.
+ ReductionNode *root = allocator.Allocate();
new (root) ReductionNode(nullptr, ranges, allocator);
mlir::IRMapping mapper;
if (failed(root->initialize(module, region, mapper)))
diff --git a/mlir/test/mlir-reduce/reduction-tree.mlir b/mlir/test/mlir-reduce/reduction-tree.mlir
index b693765fbed53..b053a111e9a16 100644
--- a/mlir/test/mlir-reduce/reduction-tree.mlir
+++ b/mlir/test/mlir-reduce/reduction-tree.mlir
@@ -85,6 +85,7 @@ func.func @br_reduction(%arg0: i1, %arg1: memref<2xf32>, %arg2: memref<2xf32>) {
// CHECK-SAME: %[[ARG1:.*]]: memref<2xf32>,
// CHECK-SAME: %[[ARG2:.*]]: memref<2xf32>) {
func.func @br_reduction_loop(%arg0: i1, %arg1: memref<2xf32>, %arg2: memref<2xf32>) {
+ // select ^bb2
cf.cond_br %arg0, ^bb1, ^bb2
^bb1:
cf.br ^bb3(%arg1 : memref<2xf32>)
@@ -93,14 +94,12 @@ func.func @br_reduction_loop(%arg0: i1, %arg1: memref<2xf32>, %arg2: memref<2xf3
cf.br ^bb3(%0 : memref<2xf32>)
^bb3(%1: memref<2xf32>):
"test.op_crash"(%1, %arg2) : (memref<2xf32>, memref<2xf32>) -> ()
+ // select ^bb4
cf.cond_br %arg0, ^bb3(%1: memref<2xf32>), ^bb4
^bb4:
return
}
-// CHECK-NEXT: cf.br ^bb1(%[[ARG1]] : memref<2xf32>)
-// CHECK-NEXT: ^bb1(%[[VAL_0:.*]]: memref<2xf32>):
-// CHECK-NEXT: "test.op_crash"(%[[VAL_0]], %[[ARG2]])
-// CHECK-NEXT: cf.br ^bb1(%[[VAL_0]] : memref<2xf32>)
+// CHECK-NEXT: "test.op_crash"(%[[ARG1]], %[[ARG2]])
// -----
@@ -123,4 +122,4 @@ func.func @switch_reduction(%arg0: i32, %arg1: memref<2xf32>, %arg2: memref<2xf3
"test.op_crash"(%1, %arg2) : (memref<2xf32>, memref<2xf32>) -> ()
return
}
-// CHECK-NEXT: "test.op_crash"(%[[ARG1]], %[[ARG2]])
\ No newline at end of file
+// CHECK-NEXT: "test.op_crash"(%[[ARG1]], %[[ARG2]])
>From 36c24d289a132ee6f9575c5b779c322a37a752b2 Mon Sep 17 00:00:00 2001
From: linuxlonelyeagle <2020382038 at qq.com>
Date: Mon, 13 Apr 2026 02:32:12 +0000
Subject: [PATCH 3/4] use getSuccessorOperands and remove
getSuccessorForwardOperands.
---
.../Dialect/ControlFlow/IR/ControlFlowOps.td | 10 ++++----
.../mlir/Interfaces/ControlFlowInterfaces.td | 9 -------
.../Dialect/ControlFlow/IR/ControlFlowOps.cpp | 24 -------------------
mlir/lib/Reducer/ReductionTreePass.cpp | 12 +++++-----
mlir/test/mlir-reduce/reduction-tree.mlir | 16 ++++++-------
5 files changed, 18 insertions(+), 53 deletions(-)
diff --git a/mlir/include/mlir/Dialect/ControlFlow/IR/ControlFlowOps.td b/mlir/include/mlir/Dialect/ControlFlow/IR/ControlFlowOps.td
index ddea3a7eae590..a441fd82546e3 100644
--- a/mlir/include/mlir/Dialect/ControlFlow/IR/ControlFlowOps.td
+++ b/mlir/include/mlir/Dialect/ControlFlow/IR/ControlFlowOps.td
@@ -65,8 +65,7 @@ def AssertOp : CF_Op<"assert",
//===----------------------------------------------------------------------===//
def BranchOp : CF_Op<"br", [
- DeclareOpInterfaceMethods<BranchOpInterface,
- ["getSuccessorForOperands", "getSuccessorForwardOperands"]>,
+ DeclareOpInterfaceMethods<BranchOpInterface, ["getSuccessorForOperands"]>,
Pure, Terminator
]> {
let summary = "Branch operation";
@@ -115,8 +114,8 @@ def BranchOp : CF_Op<"br", [
def CondBranchOp
: CF_Op<"cond_br", [AttrSizedOperandSegments,
- DeclareOpInterfaceMethods<BranchOpInterface,
- ["getSuccessorForOperands", "getSuccessorForwardOperands"]>,
+ DeclareOpInterfaceMethods<
+ BranchOpInterface, ["getSuccessorForOperands"]>,
WeightedBranchOpInterface, Pure, Terminator]> {
let summary = "Conditional branch operation";
let description = [{
@@ -242,8 +241,7 @@ def CondBranchOp
def SwitchOp : CF_Op<"switch",
[AttrSizedOperandSegments,
- DeclareOpInterfaceMethods<BranchOpInterface,
- ["getSuccessorForOperands", "getSuccessorForwardOperands"]>,
+ DeclareOpInterfaceMethods<BranchOpInterface, ["getSuccessorForOperands"]>,
Pure, Terminator]> {
let summary = "Switch operation";
let description = [{
diff --git a/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td b/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td
index d32be0c63acc7..06fa724e05fab 100644
--- a/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td
+++ b/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td
@@ -98,15 +98,6 @@ def BranchOpInterface : OpInterface<"BranchOpInterface"> {
(ins "::mlir::Type":$lhs, "::mlir::Type":$rhs), [{}],
[{ return lhs == rhs; }]
>,
- InterfaceMethod<[{
- This method is called to returns the operands of this operation that
- are passed to the specified successor's block arguments. If the successor
- is not valid for this operation, or no operands are forwarded, an empty
- ValueRange is returned.
- }],
- "ValueRange", "getSuccessorForwardOperands",
- (ins "Block *":$successor), [{}],[{ return {};}]
- >,
];
let verify = [{
diff --git a/mlir/lib/Dialect/ControlFlow/IR/ControlFlowOps.cpp b/mlir/lib/Dialect/ControlFlow/IR/ControlFlowOps.cpp
index f6eb0f05911b8..435c37bc95aac 100644
--- a/mlir/lib/Dialect/ControlFlow/IR/ControlFlowOps.cpp
+++ b/mlir/lib/Dialect/ControlFlow/IR/ControlFlowOps.cpp
@@ -296,12 +296,6 @@ Block *BranchOp::getSuccessorForOperands(ArrayRef<Attribute>) {
return getDest();
}
-ValueRange BranchOp::getSuccessorForwardOperands(Block *successor) {
- if (successor == getDest())
- return getDestOperands();
- return {};
-}
-
//===----------------------------------------------------------------------===//
// CondBranchOp
//===----------------------------------------------------------------------===//
@@ -589,14 +583,6 @@ Block *CondBranchOp::getSuccessorForOperands(ArrayRef<Attribute> operands) {
return nullptr;
}
-ValueRange CondBranchOp::getSuccessorForwardOperands(Block *successor) {
- if (successor == getTrueDest())
- return getTrueOperands();
- else if (successor == getFalseDest())
- return getFalseOperands();
- return {};
-}
-
//===----------------------------------------------------------------------===//
// SwitchOp
//===----------------------------------------------------------------------===//
@@ -1048,16 +1034,6 @@ void SwitchOp::getCanonicalizationPatterns(RewritePatternSet &results,
.add<SimplifyUniformBlockArguments>(context);
}
-ValueRange SwitchOp::getSuccessorForwardOperands(Block *successor) {
- if (successor == getDefaultDestination())
- return getDefaultOperands();
- SuccessorRange caseDests = getCaseDestinations();
- auto it = llvm::find(caseDests, successor);
- if (it == caseDests.end())
- return {};
- return getCaseOperands(std::distance(caseDests.begin(), it));
-}
-
//===----------------------------------------------------------------------===//
// TableGen'd op method definitions
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Reducer/ReductionTreePass.cpp b/mlir/lib/Reducer/ReductionTreePass.cpp
index 12358f7d71688..7e4c7a5215b5c 100644
--- a/mlir/lib/Reducer/ReductionTreePass.cpp
+++ b/mlir/lib/Reducer/ReductionTreePass.cpp
@@ -251,11 +251,11 @@ static LogicalResult eraseRedundantBlocksInRegion(ModuleOp module,
Operation *tergetTerminator = mapper.lookup(branchTerminator);
Block *selectedBlock = tergetTerminator->getSuccessor(i);
auto branchOp = cast<BranchOpInterface>(tergetTerminator);
- ValueRange selectedBlockOperands =
- branchOp.getSuccessorForwardOperands(selectedBlock);
+ mlir::SuccessorOperands selectedBlockOperands =
+ branchOp.getSuccessorOperands(i);
b.setInsertionPointAfter(tergetTerminator);
cf::BranchOp::create(b, tergetTerminator->getLoc(), selectedBlock,
- selectedBlockOperands);
+ selectedBlockOperands.getForwardedOperands());
tergetTerminator->erase();
// Apply canonicalization patterns to collapse the now-redundant branches
@@ -277,11 +277,11 @@ static LogicalResult eraseRedundantBlocksInRegion(ModuleOp module,
if (branchIdx != -1) {
Block *selectedBlock = branchTerminator->getSuccessor(branchIdx);
auto branchOp = cast<BranchOpInterface>(branchTerminator);
- ValueRange selectedBlockOperands =
- branchOp.getSuccessorForwardOperands(selectedBlock);
+ mlir::SuccessorOperands selectedBlockOperands =
+ branchOp.getSuccessorOperands(branchIdx);
b.setInsertionPointAfter(branchTerminator);
cf::BranchOp::create(b, branchTerminator->getLoc(), selectedBlock,
- selectedBlockOperands);
+ selectedBlockOperands.getForwardedOperands());
branchTerminator->erase();
(void)applyPatternsGreedily(region.getParentOp(), fPatterns, config);
}
diff --git a/mlir/test/mlir-reduce/reduction-tree.mlir b/mlir/test/mlir-reduce/reduction-tree.mlir
index b053a111e9a16..6a329f6397156 100644
--- a/mlir/test/mlir-reduce/reduction-tree.mlir
+++ b/mlir/test/mlir-reduce/reduction-tree.mlir
@@ -106,20 +106,20 @@ func.func @br_reduction_loop(%arg0: i1, %arg1: memref<2xf32>, %arg2: memref<2xf3
// CHECK-LABEL: func @switch_reduction
// CHECK-SAME: %[[ARG0:.*]]: i32,
// CHECK-SAME: %[[ARG1:.*]]: memref<2xf32>,
-// CHECK-SAME: %[[ARG2:.*]]: memref<2xf32>) {
-func.func @switch_reduction(%arg0: i32, %arg1: memref<2xf32>, %arg2: memref<2xf32>) {
+// CHECK-SAME: %[[ARG2:.*]]: memref<3xf32>)
+func.func @switch_reduction(%arg0: i32, %arg1: memref<2xf32>, %arg2: memref<3xf32>) {
cf.switch %arg0 : i32, [
default: ^bb3(%arg1 : memref<2xf32>),
- 0: ^bb1,
+ 0: ^bb1(%arg2: memref<3xf32>),
1: ^bb2
]
-^bb1:
+^bb1(%0: memref<3xf32>):
cf.br ^bb3(%arg1 : memref<2xf32>)
^bb2:
- %0 = memref.alloc() : memref<2xf32>
- cf.br ^bb3(%0 : memref<2xf32>)
-^bb3(%1: memref<2xf32>):
- "test.op_crash"(%1, %arg2) : (memref<2xf32>, memref<2xf32>) -> ()
+ %1 = memref.alloc() : memref<2xf32>
+ cf.br ^bb3(%1 : memref<2xf32>)
+^bb3(%2: memref<2xf32>):
+ "test.op_crash"(%2, %arg2) : (memref<2xf32>, memref<3xf32>) -> ()
return
}
// CHECK-NEXT: "test.op_crash"(%[[ARG1]], %[[ARG2]])
>From d4930475f237b3933d83b143317934dd8684cb1b Mon Sep 17 00:00:00 2001
From: linuxlonelyeagle <2020382038 at qq.com>
Date: Thu, 30 Apr 2026 07:50:28 +0000
Subject: [PATCH 4/4] add comment.
---
mlir/lib/Reducer/ReductionTreePass.cpp | 8 +++++---
1 file changed, 5 insertions(+), 3 deletions(-)
diff --git a/mlir/lib/Reducer/ReductionTreePass.cpp b/mlir/lib/Reducer/ReductionTreePass.cpp
index 7e4c7a5215b5c..a1952909ecefc 100644
--- a/mlir/lib/Reducer/ReductionTreePass.cpp
+++ b/mlir/lib/Reducer/ReductionTreePass.cpp
@@ -300,7 +300,7 @@ template <typename IteratorType>
static LogicalResult findOptimal(ModuleOp module, Region ®ion,
const FrozenRewritePatternSet &patterns,
const Tester &test) {
- // We separate the reduction process into 3 steps, the first one is to erase
+ // We separate the reduction process into 4 steps, the first one is to erase
// redundant operations and the second one is to apply the reducer patterns.
// In the first phase, we attempt to erase all operations within the entire
@@ -308,14 +308,16 @@ static LogicalResult findOptimal(ModuleOp module, Region &region,
if (succeeded(eraseAllOpsInRegion(module, region, test)))
return success();
+ // In the second phase, we attempt to eliminate redundant blocks. This reduces
+ // the program's execution paths.
(void)eraseRedundantBlocksInRegion(module, region, test);
- // In the second phase, we don't apply any patterns so that we only select the
+ // In the third phase, we don't apply any patterns so that we only select the
// range of operations to keep to the module stay interesting.
if (failed(findOptimal<IteratorType>(module, region, /*patterns=*/{}, test,
/*eraseOpNotInRange=*/true)))
return failure();
- // In the third phase, we suppose that no operation is redundant, so we try
+ // In the fourth phase, we suppose that no operation is redundant, so we try
// to rewrite the operation into simpler form.
return findOptimal<IteratorType>(module, region, patterns, test,
/*eraseOpNotInRange=*/false);
More information about the Mlir-commits
mailing list