[Mlir-commits] [mlir] Reland "[mlir][reducer] Add eraseRedundantBlocksInRegion to reduction-tree pass" (PR #191961)

lonely eagle llvmlistbot at llvm.org
Mon May 11 07:02:30 PDT 2026


https://github.com/linuxlonelyeagle updated https://github.com/llvm/llvm-project/pull/191961

>From 16f33bb6dad2aeda0b532d01834132177bf11d4d Mon Sep 17 00:00:00 2001
From: linuxlonelyeagle <2020382038 at qq.com>
Date: Fri, 20 Mar 2026 16:48:40 +0000
Subject: [PATCH 1/2] add eraseRedundantBlocksInRegion.

add comment.
---
 mlir/include/mlir/Reducer/ReductionNode.h     |   3 +
 mlir/lib/Reducer/CMakeLists.txt               |   1 +
 mlir/lib/Reducer/ReductionNode.cpp            |  10 ++
 mlir/lib/Reducer/ReductionTreePass.cpp        | 122 +++++++++++++++++-
 .../reduction-tree/reduction-tree.mlir        |  65 ++++++++++
 5 files changed, 198 insertions(+), 3 deletions(-)

diff --git a/mlir/include/mlir/Reducer/ReductionNode.h b/mlir/include/mlir/Reducer/ReductionNode.h
index 6ca4e13d159ac..125a7c6f6f5e7 100644
--- a/mlir/include/mlir/Reducer/ReductionNode.h
+++ b/mlir/include/mlir/Reducer/ReductionNode.h
@@ -90,6 +90,9 @@ class ReductionNode {
   /// corresponding region.
   LogicalResult initialize(ModuleOp parentModule, Region &parentRegion);
 
+  LogicalResult initialize(ModuleOp parentModule, Region &parentRegion,
+                           IRMapping &mapper);
+
 private:
   /// A custom BFS iterator. The difference between
   /// llvm/ADT/BreadthFirstIterator.h is the graph we're exploring is dynamic.
diff --git a/mlir/lib/Reducer/CMakeLists.txt b/mlir/lib/Reducer/CMakeLists.txt
index 68864e373c993..b18a4bca04fcb 100644
--- a/mlir/lib/Reducer/CMakeLists.txt
+++ b/mlir/lib/Reducer/CMakeLists.txt
@@ -9,6 +9,7 @@ add_mlir_library(MLIRReduce
    MLIRPass
    MLIRRewrite
    MLIRTransformUtils
+   MLIRControlFlowDialect
 
    DEPENDS
    MLIRReducerIncGen
diff --git a/mlir/lib/Reducer/ReductionNode.cpp b/mlir/lib/Reducer/ReductionNode.cpp
index 11aeaf77b4642..897aae0becf33 100644
--- a/mlir/lib/Reducer/ReductionNode.cpp
+++ b/mlir/lib/Reducer/ReductionNode.cpp
@@ -45,6 +45,16 @@ LogicalResult ReductionNode::initialize(ModuleOp parentModule,
   return success();
 }
 
+LogicalResult ReductionNode::initialize(ModuleOp parentModule,
+                                        Region &targetRegion,
+                                        IRMapping &mapper) {
+  module = cast<ModuleOp>(parentModule->clone(mapper));
+  // Use the first block of targetRegion to locate the cloned region.
+  Block *block = mapper.lookup(&*targetRegion.begin());
+  region = block->getParent();
+  return success();
+}
+
 /// If we haven't explored any variants from this node, we will create N
 /// variants, N is the length of `ranges` if N > 1. Otherwise, we will split the
 /// max element in `ranges` and create 2 new variants for each call.
diff --git a/mlir/lib/Reducer/ReductionTreePass.cpp b/mlir/lib/Reducer/ReductionTreePass.cpp
index 2244475e268fe..1f899080db51f 100644
--- a/mlir/lib/Reducer/ReductionTreePass.cpp
+++ b/mlir/lib/Reducer/ReductionTreePass.cpp
@@ -14,7 +14,10 @@
 //
 //===----------------------------------------------------------------------===//
 
+#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
 #include "mlir/IR/DialectInterface.h"
+#include "mlir/IR/IRMapping.h"
+#include "mlir/Interfaces/ControlFlowInterfaces.h"
 #include "mlir/Reducer/Passes.h"
 #include "mlir/Reducer/ReductionNode.h"
 #include "mlir/Reducer/ReductionPatternInterface.h"
@@ -24,6 +27,9 @@
 
 #include "llvm/ADT/ArrayRef.h"
 #include "llvm/Support/Allocator.h"
+#include "llvm/Support/DebugLog.h"
+
+#define DEBUG_TYPE "reduction-tree"
 
 namespace mlir {
 #define GEN_PASS_DEF_REDUCTIONTREEPASS
@@ -182,11 +188,117 @@ static LogicalResult eraseAllOpsInRegion(ModuleOp module, Region &region,
   return failure();
 }
 
+// Returns the first branching terminator (cond_br, switch, etc.) found in the
+// region.
+static Operation *getBranchTerminatorInRegion(Region &region) {
+  for (Block &block : region.getBlocks()) {
+    if (block.getNumSuccessors() > 1)
+      return block.getTerminator();
+  }
+  return {};
+}
+
+/// Reduces the control flow in a region by iteratively forcing branching
+/// terminators to point to a single successor. It evaluates each potential
+/// branch path and commits the reduction that results in the smallest
+/// "interesting" module.
+static LogicalResult eraseRedundantBlocksInRegion(ModuleOp module,
+                                                  Region &region,
+                                                  const Tester &test) {
+  std::pair<Tester::Interestingness, size_t> initStatus =
+      test.isInteresting(module);
+
+  // While exploring the reduction tree, we always branch from an interesting
+  // node. Thus the root node must be interesting.
+  if (initStatus.first != Tester::Interestingness::True)
+    return module.emitWarning() << "uninterested module will not be reduced";
+  llvm::SpecificBumpPtrAllocator<ReductionNode> allocator;
+
+  // We set the simplification level to Aggressive to enable block merging.
+  GreedyRewriteConfig config;
+  config.setRegionSimplificationLevel(GreedySimplifyRegionLevel::Aggressive);
+  config.setUseTopDownTraversal(true);
+
+  // Populate canonicalization patterns for cf ops. When all targets of a
+  // 'cf.cond_br' or 'cf.switch' point to the same block, they will be
+  // canonicalized into a 'cf.br'.
+  auto context = region.getContext();
+  RewritePatternSet patterns(context);
+  cf::BranchOp::getCanonicalizationPatterns(patterns, context);
+  cf::CondBranchOp::getCanonicalizationPatterns(patterns, context);
+  cf::SwitchOp::getCanonicalizationPatterns(patterns, context);
+  FrozenRewritePatternSet fPatterns = std::move(patterns);
+
+  ReductionNode *smallestNode = nullptr;
+  mlir::OpBuilder b(context);
+  while (Operation *branchTerminator = getBranchTerminatorInRegion(region)) {
+    size_t numSuccessor = branchTerminator->getNumSuccessors();
+    std::vector<ReductionNode::Range> ranges{
+        {0, std::distance(region.op_begin(), region.op_end())}};
+    // Iterate through each successor of the branching terminator to try
+    // reducing the control flow to a single-path execution.
+    int branchIdx = -1;
+    for (int i = 0, e = numSuccessor; i < e; ++i) {
+      // We allocate memory on the heap because the object will be assigned to
+      // 'smallestNode'.
+      ReductionNode *root = allocator.Allocate();
+      new (root) ReductionNode(nullptr, ranges, allocator);
+      mlir::IRMapping mapper;
+      if (failed(root->initialize(module, region, mapper)))
+        llvm_unreachable("unexpected initialization failure");
+      Operation *tergetTerminator = mapper.lookup(branchTerminator);
+      Block *selectedBlock = tergetTerminator->getSuccessor(i);
+      auto branchOp = cast<BranchOpInterface>(tergetTerminator);
+      mlir::SuccessorOperands selectedBlockOperands =
+          branchOp.getSuccessorOperands(i);
+      b.setInsertionPointAfter(tergetTerminator);
+      cf::BranchOp::create(b, tergetTerminator->getLoc(), selectedBlock,
+                           selectedBlockOperands.getForwardedOperands());
+      tergetTerminator->erase();
+
+      // Apply canonicalization patterns to collapse the now-redundant branches
+      (void)applyPatternsGreedily(root->getRegion().getParentOp(), fPatterns,
+                                  config);
+      root->update(test.isInteresting(root->getModule()));
+
+      // Track the smallest "interesting" version of the IR found so far.
+      if (root->isInteresting() == Tester::Interestingness::True &&
+          (smallestNode == nullptr ||
+           root->getSize() < smallestNode->getSize())) {
+        smallestNode = root;
+        branchIdx = i;
+      }
+    }
+
+    // If an interesting reduced branch was found, commit the change to the
+    // original region and re-apply patterns for a final cleanup.
+    if (branchIdx != -1) {
+      Block *selectedBlock = branchTerminator->getSuccessor(branchIdx);
+      auto branchOp = cast<BranchOpInterface>(branchTerminator);
+      mlir::SuccessorOperands selectedBlockOperands =
+          branchOp.getSuccessorOperands(branchIdx);
+      b.setInsertionPointAfter(branchTerminator);
+      cf::BranchOp::create(b, branchTerminator->getLoc(), selectedBlock,
+                           selectedBlockOperands.getForwardedOperands());
+      branchTerminator->erase();
+      (void)applyPatternsGreedily(region.getParentOp(), fPatterns, config);
+    }
+  }
+
+  // If no branching terminators were found (skipping the while loop),
+  // there might still be opportunities for linear block merging or
+  // We apply patterns here as a final cleanup to ensure the region is fully
+  // simplified.
+  if (smallestNode == nullptr)
+    (void)applyPatternsGreedily(region.getParentOp(), fPatterns, config);
+  return success();
+}
+
 template <typename IteratorType>
 static LogicalResult findOptimal(ModuleOp module, Region &region,
                                  const FrozenRewritePatternSet &patterns,
                                  const Tester &test) {
-  // We separate the reduction process into 3 steps, the first one is to erase
+  // We separate the reduction process into 4 steps, the first one is to erase
   // redundant operations and the second one is to apply the reducer patterns.
 
   // In the first phase, we attempt to erase all operations within the entire
@@ -194,12 +306,16 @@ static LogicalResult findOptimal(ModuleOp module, Region &region,
   if (succeeded(eraseAllOpsInRegion(module, region, test)))
     return success();
 
-  // In the second phase, we don't apply any patterns so that we only select the
+  // In the second phase, we attempt to eliminate redundant blocks. This reduces
+  // the program's execution paths.
+  (void)eraseRedundantBlocksInRegion(module, region, test);
+
+  // In the third phase, we don't apply any patterns so that we only select the
   // range of operations to keep to the module stay interesting.
   if (failed(findOptimal<IteratorType>(module, region, /*patterns=*/{}, test,
                                        /*eraseOpNotInRange=*/true)))
     return failure();
-  // In the third phase, we suppose that no operation is redundant, so we try
+  // In the fourth phase, we suppose that no operation is redundant, so we try
   // to rewrite the operation into simpler form.
   return findOptimal<IteratorType>(module, region, patterns, test,
                                    /*eraseOpNotInRange=*/false);
diff --git a/mlir/test/mlir-reduce/reduction-tree/reduction-tree.mlir b/mlir/test/mlir-reduce/reduction-tree/reduction-tree.mlir
index b235ca14d693a..8e1a575676a8b 100644
--- a/mlir/test/mlir-reduce/reduction-tree/reduction-tree.mlir
+++ b/mlir/test/mlir-reduce/reduction-tree/reduction-tree.mlir
@@ -58,3 +58,68 @@ func.func @simple4(%arg0: i1, %arg1: memref<2xf32>, %arg2: memref<2xf32>) {
 func.func @simple5() {
   return
 }
+
+// -----
+
+// CHECK-LABEL: func @br_reduction
+//  CHECK-SAME:  %[[ARG0:.*]]: i1,
+//  CHECK-SAME:  %[[ARG1:.*]]: memref<2xf32>,
+//  CHECK-SAME:  %[[ARG2:.*]]: memref<2xf32>) {
+func.func @br_reduction(%arg0: i1, %arg1: memref<2xf32>, %arg2: memref<2xf32>) {
+  cf.cond_br %arg0, ^bb1, ^bb2
+^bb1:
+  cf.br ^bb3(%arg1 : memref<2xf32>)
+^bb2:
+  %0 = memref.alloc() : memref<2xf32>
+  cf.br ^bb3(%0 : memref<2xf32>)
+^bb3(%1: memref<2xf32>):
+  "test.op_crash"(%1, %arg2) : (memref<2xf32>, memref<2xf32>) -> ()
+  return
+}
+// CHECK-NEXT: "test.op_crash"(%[[ARG1]], %[[ARG2]])
+
+// -----
+
+// CHECK-LABEL: func @br_reduction_loop
+//  CHECK-SAME:   %[[ARG0:.*]]: i1,
+//  CHECK-SAME:   %[[ARG1:.*]]: memref<2xf32>,
+//  CHECK-SAME:   %[[ARG2:.*]]: memref<2xf32>) {
+func.func @br_reduction_loop(%arg0: i1, %arg1: memref<2xf32>, %arg2: memref<2xf32>) {
+  // select ^bb2
+  cf.cond_br %arg0, ^bb1, ^bb2
+^bb1:
+  cf.br ^bb3(%arg1 : memref<2xf32>)
+^bb2:
+  %0 = memref.alloc() : memref<2xf32>
+  cf.br ^bb3(%0 : memref<2xf32>)
+^bb3(%1: memref<2xf32>):
+  "test.op_crash"(%1, %arg2) : (memref<2xf32>, memref<2xf32>) -> ()
+  // select ^bb4
+  cf.cond_br %arg0, ^bb3(%1: memref<2xf32>), ^bb4
+^bb4:
+  return
+}
+// CHECK-NEXT:  "test.op_crash"(%[[ARG1]], %[[ARG2]])
+
+// -----
+
+// CHECK-LABEL: func @switch_reduction
+//  CHECK-SAME:   %[[ARG0:.*]]: i32,
+//  CHECK-SAME:   %[[ARG1:.*]]: memref<2xf32>,
+//  CHECK-SAME:   %[[ARG2:.*]]: memref<3xf32>)
+func.func @switch_reduction(%arg0: i32, %arg1: memref<2xf32>, %arg2: memref<3xf32>) {
+  cf.switch %arg0 : i32, [
+    default: ^bb3(%arg1 : memref<2xf32>),
+    0: ^bb1(%arg2: memref<3xf32>),
+    1: ^bb2
+  ]
+^bb1(%0: memref<3xf32>):
+  cf.br ^bb3(%arg1 : memref<2xf32>)
+^bb2:
+  %1 = memref.alloc() : memref<2xf32>
+  cf.br ^bb3(%1 : memref<2xf32>)
+^bb3(%2: memref<2xf32>):
+  "test.op_crash"(%2, %arg2) : (memref<2xf32>, memref<3xf32>) -> ()
+  return
+}
+// CHECK-NEXT:  "test.op_crash"(%[[ARG1]], %[[ARG2]])

>From a3dd7353e3ba11422c20ade8206ca5d5b7044546 Mon Sep 17 00:00:00 2001
From: linuxlonelyeagle <2020382038 at qq.com>
Date: Mon, 11 May 2026 14:02:09 +0000
Subject: [PATCH 2/2] rebase main and update code.

---
 mlir/lib/Reducer/ReductionTreePass.cpp        | 174 +++++++++++++-----
 .../reduction-tree/reduction-tree.mlir        |  35 +++-
 2 files changed, 158 insertions(+), 51 deletions(-)

diff --git a/mlir/lib/Reducer/ReductionTreePass.cpp b/mlir/lib/Reducer/ReductionTreePass.cpp
index 1f899080db51f..a4cad5d19c725 100644
--- a/mlir/lib/Reducer/ReductionTreePass.cpp
+++ b/mlir/lib/Reducer/ReductionTreePass.cpp
@@ -24,6 +24,7 @@
 #include "mlir/Reducer/Tester.h"
 #include "mlir/Rewrite/FrozenRewritePatternSet.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "mlir/Transforms/RegionUtils.h"
 
 #include "llvm/ADT/ArrayRef.h"
 #include "llvm/Support/Allocator.h"
@@ -188,50 +189,64 @@ static LogicalResult eraseAllOpsInRegion(ModuleOp module, Region &region,
   return failure();
 }
 
-// Returns the first branching terminator (cond_br, switch, etc.) found in the
-// region.
-static Operation *getBranchTerminatorInRegion(Region &region) {
-  for (Block &block : region.getBlocks()) {
-    if (block.getNumSuccessors() > 1)
-      return block.getTerminator();
+/// Searches for an unvisited branch terminator within the given region based on
+/// the specified conditionality. This helper scans blocks in the \p region to
+/// find a terminator that has not yet been processed (not in \p visited). If
+/// \p isConditional is true, it looks for terminators with multiple successors
+/// (e.g., cf.cond_br). Otherwise, it looks for single-successor terminators
+/// (e.g., cf.br).
+static Operation *getBranchTerminatorInRegion(Region &region,
+                                              DenseSet<Operation *> &visited,
+                                              bool isConditional = true) {
+  auto it = llvm::find_if(region.getBlocks(), [&](Block &block) {
+    if (!block.mightHaveTerminator())
+      return false;
+    size_t numSucc = block.getNumSuccessors();
+    Operation *term = block.getTerminator();
+    return !visited.contains(term) &&
+           (isConditional ? numSucc > 1 : numSucc == 1);
+  });
+  return it != region.end() ? it->getTerminator() : nullptr;
+}
+
+/// Prunes unreachable blocks from the CFG using the \p worklist. This function
+/// iteratively removes blocks that have no predecessors. When a block is
+/// erased, its successors are added to the worklist as they may consequently
+/// become unreachable. This ensures a cascading deletion of dead-end paths in
+/// the control flow graph.
+static void pruneCFGEdges(SetVector<Block *> &workList, IRRewriter &rewriter) {
+  while (!workList.empty()) {
+    Block *b = workList.front();
+    workList.erase(workList.begin());
+    if (b->hasNoPredecessors()) {
+      for (Block *it : b->getSuccessors())
+        workList.insert(it);
+      rewriter.eraseBlock(b);
+    }
   }
-  return {};
 }
 
 /// Reduces the control flow in a region by iteratively forcing branching
 /// terminators to point to a single successor. It evaluates each potential
 /// branch path and commits the reduction that results in the smallest
 /// "interesting" module.
-static LogicalResult eraseRedundantBlocksInRegion(ModuleOp module,
-                                                  Region &region,
-                                                  const Tester &test) {
+static LogicalResult reduceConditionalsInRegion(ModuleOp module, Region &region,
+                                                const Tester &test) {
   std::pair<Tester::Interestingness, size_t> initStatus =
       test.isInteresting(module);
 
-  // While exploring the reduction tree, we always branch from an interesting
-  // node. Thus the root node must be interesting.
   if (initStatus.first != Tester::Interestingness::True)
     return module.emitWarning() << "uninterested module will not be reduced";
   llvm::SpecificBumpPtrAllocator<ReductionNode> allocator;
 
-  // We set the simplification level to Aggressive to enable block merging.
-  GreedyRewriteConfig config;
-  config.setRegionSimplificationLevel(GreedySimplifyRegionLevel::Aggressive);
-  config.setUseTopDownTraversal(true);
-
-  // Populate canonicalization patterns for cf ops. When all targets of a
-  // 'cf.cond_br' or 'cf.switch' point to the same block, they will be
-  // canonicalized into a 'cf.br'.
-  auto context = region.getContext();
-  RewritePatternSet patterns(context);
-  cf::BranchOp::getCanonicalizationPatterns(patterns, context);
-  cf::CondBranchOp::getCanonicalizationPatterns(patterns, context);
-  cf::SwitchOp::getCanonicalizationPatterns(patterns, context);
-  FrozenRewritePatternSet fPatterns = std::move(patterns);
-
   ReductionNode *smallestNode = nullptr;
-  mlir::OpBuilder b(context);
-  while (Operation *branchTerminator = getBranchTerminatorInRegion(region)) {
+  mlir::IRRewriter rewriter(region.getContext());
+  DenseSet<Operation *> visited;
+
+  // This loop attempts to convert conditional branch operations into
+  // unconditional ones.
+  while (Operation *branchTerminator =
+             getBranchTerminatorInRegion(region, visited)) {
     size_t numSuccessor = branchTerminator->getNumSuccessors();
     std::vector<ReductionNode::Range> ranges{
         {0, std::distance(region.op_begin(), region.op_end())}};
@@ -246,22 +261,21 @@ static LogicalResult eraseRedundantBlocksInRegion(ModuleOp module,
       mlir::IRMapping mapper;
       if (failed(root->initialize(module, region, mapper)))
         llvm_unreachable("unexpected initialization failure");
+
       Operation *tergetTerminator = mapper.lookup(branchTerminator);
       Block *selectedBlock = tergetTerminator->getSuccessor(i);
       auto branchOp = cast<BranchOpInterface>(tergetTerminator);
       mlir::SuccessorOperands selectedBlockOperands =
           branchOp.getSuccessorOperands(i);
-      b.setInsertionPointAfter(tergetTerminator);
-      cf::BranchOp::create(b, tergetTerminator->getLoc(), selectedBlock,
+      rewriter.setInsertionPointAfter(tergetTerminator);
+      cf::BranchOp::create(rewriter, tergetTerminator->getLoc(), selectedBlock,
                            selectedBlockOperands.getForwardedOperands());
-      tergetTerminator->erase();
-
-      // Apply canonicalization patterns to collapse the now-redundant branches
-      (void)applyPatternsGreedily(root->getRegion().getParentOp(), fPatterns,
-                                  config);
+      auto succs = llvm::to_vector(tergetTerminator->getSuccessors());
+      succs.erase(succs.begin() + i);
+      SetVector<Block *> workList(succs.begin(), succs.end());
+      rewriter.eraseOp(tergetTerminator);
+      pruneCFGEdges(workList, rewriter);
       root->update(test.isInteresting(root->getModule()));
-
-      // Track the smallest "interesting" version of the IR found so far.
       if (root->isInteresting() == Tester::Interestingness::True &&
           (smallestNode == nullptr ||
            root->getSize() < smallestNode->getSize())) {
@@ -270,27 +284,87 @@ static LogicalResult eraseRedundantBlocksInRegion(ModuleOp module,
       }
     }
 
-    // If an interesting reduced branch was found, commit the change to the
-    // original region and re-apply patterns for a final cleanup.
     if (branchIdx != -1) {
       Block *selectedBlock = branchTerminator->getSuccessor(branchIdx);
       auto branchOp = cast<BranchOpInterface>(branchTerminator);
       mlir::SuccessorOperands selectedBlockOperands =
           branchOp.getSuccessorOperands(branchIdx);
-      b.setInsertionPointAfter(branchTerminator);
-      cf::BranchOp::create(b, branchTerminator->getLoc(), selectedBlock,
+      rewriter.setInsertionPointAfter(branchTerminator);
+      cf::BranchOp::create(rewriter, branchTerminator->getLoc(), selectedBlock,
                            selectedBlockOperands.getForwardedOperands());
-      branchTerminator->erase();
-      (void)applyPatternsGreedily(region.getParentOp(), fPatterns, config);
+
+      auto succs = llvm::to_vector(branchOp->getSuccessors());
+      succs.erase(succs.begin() + branchIdx);
+      SetVector<Block *> workList(succs.begin(), succs.end());
+      rewriter.eraseOp(branchOp);
+      pruneCFGEdges(workList, rewriter);
+    } else {
+      // Insert 'branchTerminator' into visited to prevent it from being
+      // processed again.
+      visited.insert(branchTerminator);
     }
   }
+  return success();
+}
+
+/// Simplifies the Control Flow Graph (CFG) by merging blocks that have a
+/// single-successor / single-predecessor relationship. This function leverages
+/// the canonicalization patterns of 'cf.br' to perform the merge.
+static LogicalResult reduceBlockMergeInRegion(ModuleOp module, Region &region,
+                                              const Tester &test) {
+  std::pair<Tester::Interestingness, size_t> initStatus =
+      test.isInteresting(module);
+
+  if (initStatus.first != Tester::Interestingness::True)
+    return module.emitWarning() << "uninterested module will not be reduced";
+  llvm::SpecificBumpPtrAllocator<ReductionNode> allocator;
+
+  GreedyRewriteConfig config;
+  auto context = region.getContext();
+  RewritePatternSet patterns(context);
+  cf::BranchOp::getCanonicalizationPatterns(patterns, context);
+  FrozenRewritePatternSet fPatterns = std::move(patterns);
+
+  mlir::IRRewriter rewriter(context);
+  DenseSet<Operation *> visited;
+  while (Operation *branchTerminator =
+             getBranchTerminatorInRegion(region, visited, false)) {
+    std::vector<ReductionNode::Range> ranges{
+        {0, std::distance(region.op_begin(), region.op_end())}};
+    ReductionNode *root = allocator.Allocate();
+    new (root) ReductionNode(nullptr, ranges, allocator);
+    mlir::IRMapping mapper;
+    if (failed(root->initialize(module, region, mapper)))
+      llvm_unreachable("unexpected initialization failure");
+    Operation *tergetTerminator = mapper.lookup(branchTerminator);
+    bool changed = false;
+    (void)applyOpPatternsGreedily(tergetTerminator, fPatterns, config,
+                                  &changed);
+    root->update(test.isInteresting(root->getModule()));
+
+    // If `changed` is false, the pattern failed to apply; otherwise the merged
+    // module may be uninteresting. In both cases insert the terminator into
+    // `visited` to prevent it from being processed again.
+    if (changed && root->isInteresting() == Tester::Interestingness::True)
+      (void)applyOpPatternsGreedily(branchTerminator, fPatterns, config);
+    else
+      visited.insert(branchTerminator);
+  }
+  return success();
+}
+
+static LogicalResult eraseRedundantBlocksInRegion(ModuleOp module,
+                                                  Region &region,
+                                                  const Tester &test) {
+  /// We split the control-flow-graph reduction process into 2 steps.
+
+  // We attempt to simplify conditional branches into unconditional ones by
+  // picking the "interesting" path.
+  (void)reduceConditionalsInRegion(module, region, test);
 
-  // If no branching terminators were found (skipping the while loop),
-  // there might still be opportunities for linear block merging or
-  // We apply patterns here as a final cleanup to ensure the region is fully
-  // simplified.
-  if (smallestNode == nullptr)
-    (void)applyPatternsGreedily(region.getParentOp(), fPatterns, config);
+  // We merge redundant blocks that have single-successor/single-predecessor
+  // relationships using canonicalization patterns.
+  (void)reduceBlockMergeInRegion(module, region, test);
   return success();
 }
 
diff --git a/mlir/test/mlir-reduce/reduction-tree/reduction-tree.mlir b/mlir/test/mlir-reduce/reduction-tree/reduction-tree.mlir
index 8e1a575676a8b..3fd7d13a10d13 100644
--- a/mlir/test/mlir-reduce/reduction-tree/reduction-tree.mlir
+++ b/mlir/test/mlir-reduce/reduction-tree/reduction-tree.mlir
@@ -99,7 +99,10 @@ func.func @br_reduction_loop(%arg0: i1, %arg1: memref<2xf32>, %arg2: memref<2xf3
 ^bb4:
   return
 }
-// CHECK-NEXT:  "test.op_crash"(%[[ARG1]], %[[ARG2]])
+// CHECK:   cf.br ^bb1(%[[ARG1]] : memref<2xf32>)
+// CHECK: ^bb1(%[[VAL_0:.*]]: memref<2xf32>):
+// CHECK:   "test.op_crash"(%[[VAL_0]], %[[ARG2]])
+// CHECK:   cf.br ^bb1(%[[VAL_0]] : memref<2xf32>)
 
 // -----
 
@@ -123,3 +126,33 @@ func.func @switch_reduction(%arg0: i32, %arg1: memref<2xf32>, %arg2: memref<3xf3
   return
 }
 // CHECK-NEXT:  "test.op_crash"(%[[ARG1]], %[[ARG2]])
+
+// -----
+
+// This test verifies the ability to reduce unreachable code.
+
+// CHECK-LABEL: func @unreachable_code
+// CHECK-SAME:      %[[ARG0:.*]]: i1,
+// CHECK-SAME:      %[[ARG1:.*]]: memref<2xf32>,
+// CHECK-SAME:      %[[ARG2:.*]]: memref<2xf32>)
+func.func @unreachable_code(%arg0: i1, %arg1: memref<2xf32>, %arg2: memref<2xf32>) {
+  cf.br ^bb1
+^bb1:
+  cf.cond_br %arg0, ^bb2, ^bb3
+^bb2:
+  cf.br ^bb1
+^bb3:
+  %alloc = memref.alloc() : memref<2xf32>
+  cf.br ^bb1
+^bb4(%0: memref<2xf32>):
+  "test.op_crash"(%0, %arg2) : (memref<2xf32>, memref<2xf32>) -> ()
+  cf.cond_br %arg0, ^bb4(%0 : memref<2xf32>), ^bb5
+^bb5:
+  return
+}
+// CHECK:   cf.br ^bb1
+// CHECK: ^bb1:
+// CHECK:   cf.br ^bb1
+// CHECK: ^bb2(%[[VAL_0:.*]]: memref<2xf32>):
+// CHECK:   "test.op_crash"(%[[VAL_0]], %[[ARG2]])
+// CHECK:    cf.br ^bb2(%[[VAL_0]] : memref<2xf32>)



More information about the Mlir-commits mailing list