[Mlir-commits] [mlir] Reland "[mlir][reducer] Add eraseRedundantBlocksInRegion to reduction-tree pass" (PR #191961)
lonely eagle
llvmlistbot at llvm.org
Mon May 11 07:02:30 PDT 2026
https://github.com/linuxlonelyeagle updated https://github.com/llvm/llvm-project/pull/191961
>From 16f33bb6dad2aeda0b532d01834132177bf11d4d Mon Sep 17 00:00:00 2001
From: linuxlonelyeagle <2020382038 at qq.com>
Date: Fri, 20 Mar 2026 16:48:40 +0000
Subject: [PATCH 1/2] add eraseRedundantBlocksInRegion.
add comment.
---
mlir/include/mlir/Reducer/ReductionNode.h | 3 +
mlir/lib/Reducer/CMakeLists.txt | 1 +
mlir/lib/Reducer/ReductionNode.cpp | 10 ++
mlir/lib/Reducer/ReductionTreePass.cpp | 122 +++++++++++++++++-
.../reduction-tree/reduction-tree.mlir | 65 ++++++++++
5 files changed, 198 insertions(+), 3 deletions(-)
diff --git a/mlir/include/mlir/Reducer/ReductionNode.h b/mlir/include/mlir/Reducer/ReductionNode.h
index 6ca4e13d159ac..125a7c6f6f5e7 100644
--- a/mlir/include/mlir/Reducer/ReductionNode.h
+++ b/mlir/include/mlir/Reducer/ReductionNode.h
@@ -90,6 +90,9 @@ class ReductionNode {
/// corresponding region.
LogicalResult initialize(ModuleOp parentModule, Region &parentRegion);
+ LogicalResult initialize(ModuleOp parentModule, Region &parentRegion,
+ IRMapping &mapper);
+
private:
/// A custom BFS iterator. The difference between
/// llvm/ADT/BreadthFirstIterator.h is the graph we're exploring is dynamic.
diff --git a/mlir/lib/Reducer/CMakeLists.txt b/mlir/lib/Reducer/CMakeLists.txt
index 68864e373c993..b18a4bca04fcb 100644
--- a/mlir/lib/Reducer/CMakeLists.txt
+++ b/mlir/lib/Reducer/CMakeLists.txt
@@ -9,6 +9,7 @@ add_mlir_library(MLIRReduce
MLIRPass
MLIRRewrite
MLIRTransformUtils
+ MLIRControlFlowDialect
DEPENDS
MLIRReducerIncGen
diff --git a/mlir/lib/Reducer/ReductionNode.cpp b/mlir/lib/Reducer/ReductionNode.cpp
index 11aeaf77b4642..897aae0becf33 100644
--- a/mlir/lib/Reducer/ReductionNode.cpp
+++ b/mlir/lib/Reducer/ReductionNode.cpp
@@ -45,6 +45,16 @@ LogicalResult ReductionNode::initialize(ModuleOp parentModule,
return success();
}
+LogicalResult ReductionNode::initialize(ModuleOp parentModule,
+ Region &targetRegion,
+ IRMapping &mapper) {
+ module = cast<ModuleOp>(parentModule->clone(mapper));
+ // Use the first block of targetRegion to locate the cloned region.
+ Block *block = mapper.lookup(&*targetRegion.begin());
+ region = block->getParent();
+ return success();
+}
+
/// If we haven't explored any variants from this node, we will create N
/// variants, N is the length of `ranges` if N > 1. Otherwise, we will split the
/// max element in `ranges` and create 2 new variants for each call.
diff --git a/mlir/lib/Reducer/ReductionTreePass.cpp b/mlir/lib/Reducer/ReductionTreePass.cpp
index 2244475e268fe..1f899080db51f 100644
--- a/mlir/lib/Reducer/ReductionTreePass.cpp
+++ b/mlir/lib/Reducer/ReductionTreePass.cpp
@@ -14,7 +14,10 @@
//
//===----------------------------------------------------------------------===//
+#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
#include "mlir/IR/DialectInterface.h"
+#include "mlir/IR/IRMapping.h"
+#include "mlir/Interfaces/ControlFlowInterfaces.h"
#include "mlir/Reducer/Passes.h"
#include "mlir/Reducer/ReductionNode.h"
#include "mlir/Reducer/ReductionPatternInterface.h"
@@ -24,6 +27,9 @@
#include "llvm/ADT/ArrayRef.h"
#include "llvm/Support/Allocator.h"
+#include "llvm/Support/DebugLog.h"
+
+#define DEBUG_TYPE "reduction-tree"
namespace mlir {
#define GEN_PASS_DEF_REDUCTIONTREEPASS
@@ -182,11 +188,117 @@ static LogicalResult eraseAllOpsInRegion(ModuleOp module, Region ®ion,
return failure();
}
+// Returns the first branching terminator (cond_br, switch, etc.) found in the
+// region.
+static Operation *getBranchTerminatorInRegion(Region ®ion) {
+ for (Block &block : region.getBlocks()) {
+ if (block.getNumSuccessors() > 1)
+ return block.getTerminator();
+ }
+ return {};
+}
+
+/// Reduces the control flow in a region by iteratively forcing branching
+/// terminators to point to a single successor. It evaluates each potential
+/// branch path and commits the reduction that results in the smallest
+/// "interesting" module.
+static LogicalResult eraseRedundantBlocksInRegion(ModuleOp module,
+ Region ®ion,
+ const Tester &test) {
+ std::pair<Tester::Interestingness, size_t> initStatus =
+ test.isInteresting(module);
+
+ // While exploring the reduction tree, we always branch from an interesting
+ // node. Thus the root node must be interesting.
+ if (initStatus.first != Tester::Interestingness::True)
+ return module.emitWarning() << "uninterested module will not be reduced";
+ llvm::SpecificBumpPtrAllocator<ReductionNode> allocator;
+
+ // We set the simplification level to Aggressive to enable block merging.
+ GreedyRewriteConfig config;
+ config.setRegionSimplificationLevel(GreedySimplifyRegionLevel::Aggressive);
+ config.setUseTopDownTraversal(true);
+
+ // Populate canonicalization patterns for cf ops. When all targets of a
+ // 'cf.cond_br' or 'cf.switch' point to the same block, they will be
+ // canonicalized into a 'cf.br'.
+ auto context = region.getContext();
+ RewritePatternSet patterns(context);
+ cf::BranchOp::getCanonicalizationPatterns(patterns, context);
+ cf::CondBranchOp::getCanonicalizationPatterns(patterns, context);
+ cf::SwitchOp::getCanonicalizationPatterns(patterns, context);
+ FrozenRewritePatternSet fPatterns = std::move(patterns);
+
+ ReductionNode *smallestNode = nullptr;
+ mlir::OpBuilder b(context);
+ while (Operation *branchTerminator = getBranchTerminatorInRegion(region)) {
+ size_t numSuccessor = branchTerminator->getNumSuccessors();
+ std::vector<ReductionNode::Range> ranges{
+ {0, std::distance(region.op_begin(), region.op_end())}};
+ // Iterate through each successor of the branching terminator to try
+ // reducing the control flow to a single-path execution.
+ int branchIdx = -1;
+ for (int i = 0, e = numSuccessor; i < e; ++i) {
+ // We allocate memory on the heap because the object will be assigned to
+ // 'smallestNode'.
+ ReductionNode *root = allocator.Allocate();
+ new (root) ReductionNode(nullptr, ranges, allocator);
+ mlir::IRMapping mapper;
+ if (failed(root->initialize(module, region, mapper)))
+ llvm_unreachable("unexpected initialization failure");
+ Operation *tergetTerminator = mapper.lookup(branchTerminator);
+ Block *selectedBlock = tergetTerminator->getSuccessor(i);
+ auto branchOp = cast<BranchOpInterface>(tergetTerminator);
+ mlir::SuccessorOperands selectedBlockOperands =
+ branchOp.getSuccessorOperands(i);
+ b.setInsertionPointAfter(tergetTerminator);
+ cf::BranchOp::create(b, tergetTerminator->getLoc(), selectedBlock,
+ selectedBlockOperands.getForwardedOperands());
+ tergetTerminator->erase();
+
+ // Apply canonicalization patterns to collapse the now-redundant branches
+ (void)applyPatternsGreedily(root->getRegion().getParentOp(), fPatterns,
+ config);
+ root->update(test.isInteresting(root->getModule()));
+
+ // Track the smallest "interesting" version of the IR found so far.
+ if (root->isInteresting() == Tester::Interestingness::True &&
+ (smallestNode == nullptr ||
+ root->getSize() < smallestNode->getSize())) {
+ smallestNode = root;
+ branchIdx = i;
+ }
+ }
+
+ // If an interesting reduced branch was found, commit the change to the
+ // original region and re-apply patterns for a final cleanup.
+ if (branchIdx != -1) {
+ Block *selectedBlock = branchTerminator->getSuccessor(branchIdx);
+ auto branchOp = cast<BranchOpInterface>(branchTerminator);
+ mlir::SuccessorOperands selectedBlockOperands =
+ branchOp.getSuccessorOperands(branchIdx);
+ b.setInsertionPointAfter(branchTerminator);
+ cf::BranchOp::create(b, branchTerminator->getLoc(), selectedBlock,
+ selectedBlockOperands.getForwardedOperands());
+ branchTerminator->erase();
+ (void)applyPatternsGreedily(region.getParentOp(), fPatterns, config);
+ }
+ }
+
+ // If no branching terminators were found (skipping the while loop),
+ // there might still be opportunities for linear block merging or
+ // We apply patterns here as a final cleanup to ensure the region is fully
+ // simplified.
+ if (smallestNode == nullptr)
+ (void)applyPatternsGreedily(region.getParentOp(), fPatterns, config);
+ return success();
+}
+
template <typename IteratorType>
static LogicalResult findOptimal(ModuleOp module, Region ®ion,
const FrozenRewritePatternSet &patterns,
const Tester &test) {
- // We separate the reduction process into 3 steps, the first one is to erase
+ // We separate the reduction process into 4 steps, the first one is to erase
// redundant operations and the second one is to apply the reducer patterns.
// In the first phase, we attempt to erase all operations within the entire
@@ -194,12 +306,16 @@ static LogicalResult findOptimal(ModuleOp module, Region ®ion,
if (succeeded(eraseAllOpsInRegion(module, region, test)))
return success();
- // In the second phase, we don't apply any patterns so that we only select the
+ // In the second phase, we attempt to eliminate redundant blocks. This reduces
+ // the program's execution paths.
+ (void)eraseRedundantBlocksInRegion(module, region, test);
+
+ // In the third phase, we don't apply any patterns so that we only select the
// range of operations to keep to the module stay interesting.
if (failed(findOptimal<IteratorType>(module, region, /*patterns=*/{}, test,
/*eraseOpNotInRange=*/true)))
return failure();
- // In the third phase, we suppose that no operation is redundant, so we try
+ // In the fourth phase, we suppose that no operation is redundant, so we try
// to rewrite the operation into simpler form.
return findOptimal<IteratorType>(module, region, patterns, test,
/*eraseOpNotInRange=*/false);
diff --git a/mlir/test/mlir-reduce/reduction-tree/reduction-tree.mlir b/mlir/test/mlir-reduce/reduction-tree/reduction-tree.mlir
index b235ca14d693a..8e1a575676a8b 100644
--- a/mlir/test/mlir-reduce/reduction-tree/reduction-tree.mlir
+++ b/mlir/test/mlir-reduce/reduction-tree/reduction-tree.mlir
@@ -58,3 +58,68 @@ func.func @simple4(%arg0: i1, %arg1: memref<2xf32>, %arg2: memref<2xf32>) {
func.func @simple5() {
return
}
+
+// -----
+
+// CHECK-LABEL: func @br_reduction
+// CHECK-SAME: %[[ARG0:.*]]: i1,
+// CHECK-SAME: %[[ARG1:.*]]: memref<2xf32>,
+// CHECK-SAME: %[[ARG2:.*]]: memref<2xf32>) {
+func.func @br_reduction(%arg0: i1, %arg1: memref<2xf32>, %arg2: memref<2xf32>) {
+ cf.cond_br %arg0, ^bb1, ^bb2
+^bb1:
+ cf.br ^bb3(%arg1 : memref<2xf32>)
+^bb2:
+ %0 = memref.alloc() : memref<2xf32>
+ cf.br ^bb3(%0 : memref<2xf32>)
+^bb3(%1: memref<2xf32>):
+ "test.op_crash"(%1, %arg2) : (memref<2xf32>, memref<2xf32>) -> ()
+ return
+}
+// CHECK-NEXT: "test.op_crash"(%[[ARG1]], %[[ARG2]])
+
+// -----
+
+// CHECK-LABEL: func @br_reduction_loop
+// CHECK-SAME: %[[ARG0:.*]]: i1,
+// CHECK-SAME: %[[ARG1:.*]]: memref<2xf32>,
+// CHECK-SAME: %[[ARG2:.*]]: memref<2xf32>) {
+func.func @br_reduction_loop(%arg0: i1, %arg1: memref<2xf32>, %arg2: memref<2xf32>) {
+ // select ^bb2
+ cf.cond_br %arg0, ^bb1, ^bb2
+^bb1:
+ cf.br ^bb3(%arg1 : memref<2xf32>)
+^bb2:
+ %0 = memref.alloc() : memref<2xf32>
+ cf.br ^bb3(%0 : memref<2xf32>)
+^bb3(%1: memref<2xf32>):
+ "test.op_crash"(%1, %arg2) : (memref<2xf32>, memref<2xf32>) -> ()
+ // select ^bb4
+ cf.cond_br %arg0, ^bb3(%1: memref<2xf32>), ^bb4
+^bb4:
+ return
+}
+// CHECK-NEXT: "test.op_crash"(%[[ARG1]], %[[ARG2]])
+
+// -----
+
+// CHECK-LABEL: func @switch_reduction
+// CHECK-SAME: %[[ARG0:.*]]: i32,
+// CHECK-SAME: %[[ARG1:.*]]: memref<2xf32>,
+// CHECK-SAME: %[[ARG2:.*]]: memref<3xf32>)
+func.func @switch_reduction(%arg0: i32, %arg1: memref<2xf32>, %arg2: memref<3xf32>) {
+ cf.switch %arg0 : i32, [
+ default: ^bb3(%arg1 : memref<2xf32>),
+ 0: ^bb1(%arg2: memref<3xf32>),
+ 1: ^bb2
+ ]
+^bb1(%0: memref<3xf32>):
+ cf.br ^bb3(%arg1 : memref<2xf32>)
+^bb2:
+ %1 = memref.alloc() : memref<2xf32>
+ cf.br ^bb3(%1 : memref<2xf32>)
+^bb3(%2: memref<2xf32>):
+ "test.op_crash"(%2, %arg2) : (memref<2xf32>, memref<3xf32>) -> ()
+ return
+}
+// CHECK-NEXT: "test.op_crash"(%[[ARG1]], %[[ARG2]])
>From a3dd7353e3ba11422c20ade8206ca5d5b7044546 Mon Sep 17 00:00:00 2001
From: linuxlonelyeagle <2020382038 at qq.com>
Date: Mon, 11 May 2026 14:02:09 +0000
Subject: [PATCH 2/2] rebase main and update code.
---
mlir/lib/Reducer/ReductionTreePass.cpp | 174 +++++++++++++-----
.../reduction-tree/reduction-tree.mlir | 35 +++-
2 files changed, 158 insertions(+), 51 deletions(-)
diff --git a/mlir/lib/Reducer/ReductionTreePass.cpp b/mlir/lib/Reducer/ReductionTreePass.cpp
index 1f899080db51f..a4cad5d19c725 100644
--- a/mlir/lib/Reducer/ReductionTreePass.cpp
+++ b/mlir/lib/Reducer/ReductionTreePass.cpp
@@ -24,6 +24,7 @@
#include "mlir/Reducer/Tester.h"
#include "mlir/Rewrite/FrozenRewritePatternSet.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "mlir/Transforms/RegionUtils.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/Support/Allocator.h"
@@ -188,50 +189,64 @@ static LogicalResult eraseAllOpsInRegion(ModuleOp module, Region ®ion,
return failure();
}
-// Returns the first branching terminator (cond_br, switch, etc.) found in the
-// region.
-static Operation *getBranchTerminatorInRegion(Region ®ion) {
- for (Block &block : region.getBlocks()) {
- if (block.getNumSuccessors() > 1)
- return block.getTerminator();
+/// Searches for an unvisited branch terminator within the given region based on
+/// the specified conditionality. This helper scans blocks in the \p region to
+/// find a terminator that has not yet been processed (not in \p visited). If
+/// \p isConditional is true, it looks for terminators with multiple successors
+/// (e.g., cf.cond_br). Otherwise, it looks for single-successor terminators
+/// (e.g., cf.br). Returns nullptr if no such terminator exists.
+static Operation *getBranchTerminatorInRegion(Region ®ion,
+ DenseSet<Operation *> &visited,
+ bool isConditional = true) {
+ auto it = llvm::find_if(region.getBlocks(), [&](Block &block) {
+ if (!block.mightHaveTerminator())
+ return false;
+ size_t numSucc = block.getNumSuccessors();
+ Operation *term = block.getTerminator();
+ return !visited.contains(term) &&
+ (isConditional ? numSucc > 1 : numSucc == 1);
+ });
+ return it != region.end() ? it->getTerminator() : nullptr;
+}
+
+/// Prunes unreachable blocks from the CFG using the \p worklist. This function
+/// iteratively removes blocks that have no predecessors. When a block is
+/// erased, its successors are added to the worklist as they may consequently
+/// become unreachable. This ensures a cascading deletion of unreachable paths
+/// in the control flow graph.
+static void pruneCFGEdges(SetVector<Block *> &workList, IRRewriter &rewriter) {
+ while (!workList.empty()) {
+ Block *b = workList.front();
+ workList.erase(workList.begin());
+ if (b->hasNoPredecessors()) {
+ for (Block *it : b->getSuccessors())
+ workList.insert(it);
+ rewriter.eraseBlock(b);
+ }
}
- return {};
}
/// Reduces the control flow in a region by iteratively forcing branching
/// terminators to point to a single successor. It evaluates each potential
/// branch path and commits the reduction that results in the smallest
/// "interesting" module.
-static LogicalResult eraseRedundantBlocksInRegion(ModuleOp module,
- Region ®ion,
- const Tester &test) {
+static LogicalResult reduceConditionalsInRegion(ModuleOp module, Region ®ion,
+ const Tester &test) {
std::pair<Tester::Interestingness, size_t> initStatus =
test.isInteresting(module);
- // While exploring the reduction tree, we always branch from an interesting
- // node. Thus the root node must be interesting.
if (initStatus.first != Tester::Interestingness::True)
return module.emitWarning() << "uninterested module will not be reduced";
llvm::SpecificBumpPtrAllocator<ReductionNode> allocator;
- // We set the simplification level to Aggressive to enable block merging.
- GreedyRewriteConfig config;
- config.setRegionSimplificationLevel(GreedySimplifyRegionLevel::Aggressive);
- config.setUseTopDownTraversal(true);
-
- // Populate canonicalization patterns for cf ops. When all targets of a
- // 'cf.cond_br' or 'cf.switch' point to the same block, they will be
- // canonicalized into a 'cf.br'.
- auto context = region.getContext();
- RewritePatternSet patterns(context);
- cf::BranchOp::getCanonicalizationPatterns(patterns, context);
- cf::CondBranchOp::getCanonicalizationPatterns(patterns, context);
- cf::SwitchOp::getCanonicalizationPatterns(patterns, context);
- FrozenRewritePatternSet fPatterns = std::move(patterns);
-
ReductionNode *smallestNode = nullptr;
- mlir::OpBuilder b(context);
- while (Operation *branchTerminator = getBranchTerminatorInRegion(region)) {
+ mlir::IRRewriter rewriter(region.getContext());
+ DenseSet<Operation *> visited;
+
+ // This loop attempts to convert conditional branch operations into
+ // unconditional ones.
+ while (Operation *branchTerminator =
+ getBranchTerminatorInRegion(region, visited)) {
size_t numSuccessor = branchTerminator->getNumSuccessors();
std::vector<ReductionNode::Range> ranges{
{0, std::distance(region.op_begin(), region.op_end())}};
@@ -246,22 +261,21 @@ static LogicalResult eraseRedundantBlocksInRegion(ModuleOp module,
mlir::IRMapping mapper;
if (failed(root->initialize(module, region, mapper)))
llvm_unreachable("unexpected initialization failure");
+
Operation *tergetTerminator = mapper.lookup(branchTerminator);
Block *selectedBlock = tergetTerminator->getSuccessor(i);
auto branchOp = cast<BranchOpInterface>(tergetTerminator);
mlir::SuccessorOperands selectedBlockOperands =
branchOp.getSuccessorOperands(i);
- b.setInsertionPointAfter(tergetTerminator);
- cf::BranchOp::create(b, tergetTerminator->getLoc(), selectedBlock,
+ rewriter.setInsertionPointAfter(tergetTerminator);
+ cf::BranchOp::create(rewriter, tergetTerminator->getLoc(), selectedBlock,
selectedBlockOperands.getForwardedOperands());
- tergetTerminator->erase();
-
- // Apply canonicalization patterns to collapse the now-redundant branches
- (void)applyPatternsGreedily(root->getRegion().getParentOp(), fPatterns,
- config);
+ auto succs = llvm::to_vector(tergetTerminator->getSuccessors());
+ succs.erase(succs.begin() + i);
+ SetVector<Block *> workList(succs.begin(), succs.end());
+ rewriter.eraseOp(tergetTerminator);
+ pruneCFGEdges(workList, rewriter);
root->update(test.isInteresting(root->getModule()));
-
- // Track the smallest "interesting" version of the IR found so far.
if (root->isInteresting() == Tester::Interestingness::True &&
(smallestNode == nullptr ||
root->getSize() < smallestNode->getSize())) {
@@ -270,27 +284,87 @@ static LogicalResult eraseRedundantBlocksInRegion(ModuleOp module,
}
}
- // If an interesting reduced branch was found, commit the change to the
- // original region and re-apply patterns for a final cleanup.
if (branchIdx != -1) {
Block *selectedBlock = branchTerminator->getSuccessor(branchIdx);
auto branchOp = cast<BranchOpInterface>(branchTerminator);
mlir::SuccessorOperands selectedBlockOperands =
branchOp.getSuccessorOperands(branchIdx);
- b.setInsertionPointAfter(branchTerminator);
- cf::BranchOp::create(b, branchTerminator->getLoc(), selectedBlock,
+ rewriter.setInsertionPointAfter(branchTerminator);
+ cf::BranchOp::create(rewriter, branchTerminator->getLoc(), selectedBlock,
selectedBlockOperands.getForwardedOperands());
- branchTerminator->erase();
- (void)applyPatternsGreedily(region.getParentOp(), fPatterns, config);
+
+ auto succs = llvm::to_vector(branchOp->getSuccessors());
+ succs.erase(succs.begin() + branchIdx);
+ SetVector<Block *> workList(succs.begin(), succs.end());
+ rewriter.eraseOp(branchOp);
+ pruneCFGEdges(workList, rewriter);
+ } else {
+ // Insert 'branchTerminator' into visited to prevent it from being
+ // processed again.
+ visited.insert(branchTerminator);
}
}
+ return success();
+}
+
+/// Simplifies the Control Flow Graph (CFG) by merging blocks that have a
+/// single-successor / single-predecessor relationship. This function leverages
+/// the canonicalization patterns of 'cf.br' to perform the merge.
+static LogicalResult reduceBlockMergeInRegion(ModuleOp module, Region ®ion,
+ const Tester &test) {
+ std::pair<Tester::Interestingness, size_t> initStatus =
+ test.isInteresting(module);
+
+ if (initStatus.first != Tester::Interestingness::True)
+ return module.emitWarning() << "uninterested module will not be reduced";
+ llvm::SpecificBumpPtrAllocator<ReductionNode> allocator;
+
+ GreedyRewriteConfig config;
+ auto context = region.getContext();
+ RewritePatternSet patterns(context);
+ cf::BranchOp::getCanonicalizationPatterns(patterns, context);
+ FrozenRewritePatternSet fPatterns = std::move(patterns);
+
+ mlir::IRRewriter rewriter(context);
+ DenseSet<Operation *> visited;
+ while (Operation *branchTerminator =
+ getBranchTerminatorInRegion(region, visited, false)) {
+ std::vector<ReductionNode::Range> ranges{
+ {0, std::distance(region.op_begin(), region.op_end())}};
+ ReductionNode *root = allocator.Allocate();
+ new (root) ReductionNode(nullptr, ranges, allocator);
+ mlir::IRMapping mapper;
+ if (failed(root->initialize(module, region, mapper)))
+ llvm_unreachable("unexpected initialization failure");
+ Operation *tergetTerminator = mapper.lookup(branchTerminator);
+ bool changed = false;
+ (void)applyOpPatternsGreedily(tergetTerminator, fPatterns, config,
+ &changed);
+ root->update(test.isInteresting(root->getModule()));
+
+ // Commit the merge to the original IR only if a pattern was applied and
+ // the clone is still interesting. Otherwise, insert the terminator into
+ // visited to prevent it from being processed again.
+ if (changed && root->isInteresting() == Tester::Interestingness::True)
+ (void)applyOpPatternsGreedily(branchTerminator, fPatterns, config);
+ else
+ visited.insert(branchTerminator);
+ }
+ return success();
+}
+
+static LogicalResult eraseRedundantBlocksInRegion(ModuleOp module,
+ Region ®ion,
+ const Tester &test) {
+ /// We separate the control-flow-graph reduction process into 2 steps.
+
+ // We attempt to simplify conditional branches into unconditional ones by
+ // picking the "interesting" path.
+ (void)reduceConditionalsInRegion(module, region, test);
- // If no branching terminators were found (skipping the while loop),
- // there might still be opportunities for linear block merging or
- // We apply patterns here as a final cleanup to ensure the region is fully
- // simplified.
- if (smallestNode == nullptr)
- (void)applyPatternsGreedily(region.getParentOp(), fPatterns, config);
+ // We merge redundant blocks that have single-successor/single-predecessor
+ // relationships using canonicalization patterns.
+ (void)reduceBlockMergeInRegion(module, region, test);
return success();
}
diff --git a/mlir/test/mlir-reduce/reduction-tree/reduction-tree.mlir b/mlir/test/mlir-reduce/reduction-tree/reduction-tree.mlir
index 8e1a575676a8b..3fd7d13a10d13 100644
--- a/mlir/test/mlir-reduce/reduction-tree/reduction-tree.mlir
+++ b/mlir/test/mlir-reduce/reduction-tree/reduction-tree.mlir
@@ -99,7 +99,10 @@ func.func @br_reduction_loop(%arg0: i1, %arg1: memref<2xf32>, %arg2: memref<2xf3
^bb4:
return
}
-// CHECK-NEXT: "test.op_crash"(%[[ARG1]], %[[ARG2]])
+// CHECK: cf.br ^bb1(%[[ARG1]] : memref<2xf32>)
+// CHECK: ^bb1(%[[VAL_0:.*]]: memref<2xf32>):
+// CHECK: "test.op_crash"(%[[VAL_0]], %[[ARG2]])
+// CHECK: cf.br ^bb1(%[[VAL_0]] : memref<2xf32>)
// -----
@@ -123,3 +126,33 @@ func.func @switch_reduction(%arg0: i32, %arg1: memref<2xf32>, %arg2: memref<3xf3
return
}
// CHECK-NEXT: "test.op_crash"(%[[ARG1]], %[[ARG2]])
+
+// -----
+
+// This test verifies the ability to reduce unreachable code.
+
+// CHECK-LABEL: func @unreachable_code
+// CHECK-SAME: %[[ARG0:.*]]: i1,
+// CHECK-SAME: %[[ARG1:.*]]: memref<2xf32>,
+// CHECK-SAME: %[[ARG2:.*]]: memref<2xf32>)
+func.func @unreachable_code(%arg0: i1, %arg1: memref<2xf32>, %arg2: memref<2xf32>) {
+ cf.br ^bb1
+^bb1:
+ cf.cond_br %arg0, ^bb2, ^bb3
+^bb2:
+ cf.br ^bb1
+^bb3:
+ %alloc = memref.alloc() : memref<2xf32>
+ cf.br ^bb1
+^bb4(%0: memref<2xf32>):
+ "test.op_crash"(%0, %arg2) : (memref<2xf32>, memref<2xf32>) -> ()
+ cf.cond_br %arg0, ^bb4(%0 : memref<2xf32>), ^bb5
+^bb5:
+ return
+}
+// CHECK: cf.br ^bb1
+// CHECK: ^bb1:
+// CHECK: cf.br ^bb1
+// CHECK: ^bb2(%[[VAL_0:.*]]: memref<2xf32>):
+// CHECK: "test.op_crash"(%[[VAL_0]], %[[ARG2]])
+// CHECK: cf.br ^bb2(%[[VAL_0]] : memref<2xf32>)
More information about the Mlir-commits
mailing list