[Mlir-commits] [mlir] ce1a9fd - Reland "[mlir][reducer] Add eraseRedundantBlocksInRegion and getSuccessorForwardOperands API to BranchOpInterface" (#189253)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Sun Apr 5 21:54:20 PDT 2026


Author: lonely eagle
Date: 2026-04-06T12:54:15+08:00
New Revision: ce1a9fd76640929fe340c5c5d1bb493ea09ca9bc

URL: https://github.com/llvm/llvm-project/commit/ce1a9fd76640929fe340c5c5d1bb493ea09ca9bc
DIFF: https://github.com/llvm/llvm-project/commit/ce1a9fd76640929fe340c5c5d1bb493ea09ca9bc.diff

LOG: Reland "[mlir][reducer] Add eraseRedundantBlocksInRegion and getSuccessorForwardOperands API to BranchOpInterface" (#189253)

This PR relands https://github.com/llvm/llvm-project/pull/187864 after
fixing the undefined-symbol and memory-leak issues in the original change
(for the previous issue, see
https://github.com/llvm/llvm-project/pull/189150).

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/ControlFlow/IR/ControlFlowOps.td
    mlir/include/mlir/Interfaces/ControlFlowInterfaces.td
    mlir/include/mlir/Reducer/ReductionNode.h
    mlir/lib/Dialect/ControlFlow/IR/ControlFlowOps.cpp
    mlir/lib/Reducer/CMakeLists.txt
    mlir/lib/Reducer/ReductionNode.cpp
    mlir/lib/Reducer/ReductionTreePass.cpp
    mlir/test/mlir-reduce/reduction-tree.mlir

Removed: 
    


################################################################################
diff --git a/mlir/include/mlir/Dialect/ControlFlow/IR/ControlFlowOps.td b/mlir/include/mlir/Dialect/ControlFlow/IR/ControlFlowOps.td
index a441fd82546e3..ddea3a7eae590 100644
--- a/mlir/include/mlir/Dialect/ControlFlow/IR/ControlFlowOps.td
+++ b/mlir/include/mlir/Dialect/ControlFlow/IR/ControlFlowOps.td
@@ -65,7 +65,8 @@ def AssertOp : CF_Op<"assert",
 //===----------------------------------------------------------------------===//
 
 def BranchOp : CF_Op<"br", [
-    DeclareOpInterfaceMethods<BranchOpInterface, ["getSuccessorForOperands"]>,
+    DeclareOpInterfaceMethods<BranchOpInterface,
+    ["getSuccessorForOperands", "getSuccessorForwardOperands"]>,
     Pure, Terminator
   ]> {
   let summary = "Branch operation";
@@ -114,8 +115,8 @@ def BranchOp : CF_Op<"br", [
 
 def CondBranchOp
     : CF_Op<"cond_br", [AttrSizedOperandSegments,
-                        DeclareOpInterfaceMethods<
-                            BranchOpInterface, ["getSuccessorForOperands"]>,
+                        DeclareOpInterfaceMethods<BranchOpInterface,
+                        ["getSuccessorForOperands", "getSuccessorForwardOperands"]>,
                         WeightedBranchOpInterface, Pure, Terminator]> {
   let summary = "Conditional branch operation";
   let description = [{
@@ -241,7 +242,8 @@ def CondBranchOp
 
 def SwitchOp : CF_Op<"switch",
     [AttrSizedOperandSegments,
-     DeclareOpInterfaceMethods<BranchOpInterface, ["getSuccessorForOperands"]>,
+     DeclareOpInterfaceMethods<BranchOpInterface,
+     ["getSuccessorForOperands", "getSuccessorForwardOperands"]>,
      Pure, Terminator]> {
   let summary = "Switch operation";
   let description = [{

diff --git a/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td b/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td
index 06fa724e05fab..d32be0c63acc7 100644
--- a/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td
+++ b/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td
@@ -98,6 +98,15 @@ def BranchOpInterface : OpInterface<"BranchOpInterface"> {
       (ins "::mlir::Type":$lhs, "::mlir::Type":$rhs), [{}],
        [{ return lhs == rhs; }]
     >,
+    InterfaceMethod<[{
+      This method is called to return the operands of this operation that
+      are passed to the specified successor's block arguments. If the successor
+      is not valid for this operation, or no operands are forwarded, an empty
+      ValueRange is returned.
+      }],
+      "ValueRange", "getSuccessorForwardOperands",
+      (ins "Block *":$successor), [{}],[{ return {};}]
+    >,
   ];
 
   let verify = [{

diff --git a/mlir/include/mlir/Reducer/ReductionNode.h b/mlir/include/mlir/Reducer/ReductionNode.h
index 6ca4e13d159ac..125a7c6f6f5e7 100644
--- a/mlir/include/mlir/Reducer/ReductionNode.h
+++ b/mlir/include/mlir/Reducer/ReductionNode.h
@@ -90,6 +90,9 @@ class ReductionNode {
   /// corresponding region.
   LogicalResult initialize(ModuleOp parentModule, Region &parentRegion);
 
+  LogicalResult initialize(ModuleOp parentModule, Region &parentRegion,
+                           IRMapping &mapper);
+
 private:
   /// A custom BFS iterator. The difference between
   /// llvm/ADT/BreadthFirstIterator.h is the graph we're exploring is dynamic.

diff --git a/mlir/lib/Dialect/ControlFlow/IR/ControlFlowOps.cpp b/mlir/lib/Dialect/ControlFlow/IR/ControlFlowOps.cpp
index 435c37bc95aac..f6eb0f05911b8 100644
--- a/mlir/lib/Dialect/ControlFlow/IR/ControlFlowOps.cpp
+++ b/mlir/lib/Dialect/ControlFlow/IR/ControlFlowOps.cpp
@@ -296,6 +296,12 @@ Block *BranchOp::getSuccessorForOperands(ArrayRef<Attribute>) {
   return getDest();
 }
 
+ValueRange BranchOp::getSuccessorForwardOperands(Block *successor) {
+  if (successor == getDest())
+    return getDestOperands();
+  return {};
+}
+
 //===----------------------------------------------------------------------===//
 // CondBranchOp
 //===----------------------------------------------------------------------===//
@@ -583,6 +589,14 @@ Block *CondBranchOp::getSuccessorForOperands(ArrayRef<Attribute> operands) {
   return nullptr;
 }
 
+ValueRange CondBranchOp::getSuccessorForwardOperands(Block *successor) {
+  if (successor == getTrueDest())
+    return getTrueOperands();
+  else if (successor == getFalseDest())
+    return getFalseOperands();
+  return {};
+}
+
 //===----------------------------------------------------------------------===//
 // SwitchOp
 //===----------------------------------------------------------------------===//
@@ -1034,6 +1048,16 @@ void SwitchOp::getCanonicalizationPatterns(RewritePatternSet &results,
       .add<SimplifyUniformBlockArguments>(context);
 }
 
+ValueRange SwitchOp::getSuccessorForwardOperands(Block *successor) {
+  if (successor == getDefaultDestination())
+    return getDefaultOperands();
+  SuccessorRange caseDests = getCaseDestinations();
+  auto it = llvm::find(caseDests, successor);
+  if (it == caseDests.end())
+    return {};
+  return getCaseOperands(std::distance(caseDests.begin(), it));
+}
+
 //===----------------------------------------------------------------------===//
 // TableGen'd op method definitions
 //===----------------------------------------------------------------------===//

diff --git a/mlir/lib/Reducer/CMakeLists.txt b/mlir/lib/Reducer/CMakeLists.txt
index 68864e373c993..b18a4bca04fcb 100644
--- a/mlir/lib/Reducer/CMakeLists.txt
+++ b/mlir/lib/Reducer/CMakeLists.txt
@@ -9,6 +9,7 @@ add_mlir_library(MLIRReduce
    MLIRPass
    MLIRRewrite
    MLIRTransformUtils
+   MLIRControlFlowDialect
 
    DEPENDS
    MLIRReducerIncGen

diff --git a/mlir/lib/Reducer/ReductionNode.cpp b/mlir/lib/Reducer/ReductionNode.cpp
index 11aeaf77b4642..897aae0becf33 100644
--- a/mlir/lib/Reducer/ReductionNode.cpp
+++ b/mlir/lib/Reducer/ReductionNode.cpp
@@ -45,6 +45,16 @@ LogicalResult ReductionNode::initialize(ModuleOp parentModule,
   return success();
 }
 
+LogicalResult ReductionNode::initialize(ModuleOp parentModule,
+                                        Region &targetRegion,
+                                        IRMapping &mapper) {
+  module = cast<ModuleOp>(parentModule->clone(mapper));
+  // Use the first block of targetRegion to locate the cloned region.
+  Block *block = mapper.lookup(&*targetRegion.begin());
+  region = block->getParent();
+  return success();
+}
+
 /// If we haven't explored any variants from this node, we will create N
 /// variants, N is the length of `ranges` if N > 1. Otherwise, we will split the
 /// max element in `ranges` and create 2 new variants for each call.

diff --git a/mlir/lib/Reducer/ReductionTreePass.cpp b/mlir/lib/Reducer/ReductionTreePass.cpp
index 83497143d9669..12358f7d71688 100644
--- a/mlir/lib/Reducer/ReductionTreePass.cpp
+++ b/mlir/lib/Reducer/ReductionTreePass.cpp
@@ -14,7 +14,10 @@
 //
 //===----------------------------------------------------------------------===//
 
+#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
 #include "mlir/IR/DialectInterface.h"
+#include "mlir/IR/IRMapping.h"
+#include "mlir/Interfaces/ControlFlowInterfaces.h"
 #include "mlir/Reducer/Passes.h"
 #include "mlir/Reducer/ReductionNode.h"
 #include "mlir/Reducer/ReductionPatternInterface.h"
@@ -24,6 +27,9 @@
 
 #include "llvm/ADT/ArrayRef.h"
 #include "llvm/Support/Allocator.h"
+#include "llvm/Support/DebugLog.h"
+
+#define DEBUG_TYPE "reduction-tree"
 
 namespace mlir {
 #define GEN_PASS_DEF_REDUCTIONTREEPASS
@@ -184,6 +190,112 @@ static LogicalResult eraseAllOpsInRegion(ModuleOp module, Region &region,
   return failure();
 }
 
+// Returns the first branching terminator (cond_br, switch, etc.) found in the
+// region.
+static Operation *getBranchTerminatorInRegion(Region &region) {
+  for (Block &block : region.getBlocks()) {
+    if (block.getNumSuccessors() > 1)
+      return block.getTerminator();
+  }
+  return {};
+}
+
+/// Reduces the control flow in a region by iteratively forcing branching
+/// terminators to point to a single successor. It evaluates each potential
+/// branch path and commits the reduction that results in the smallest
+/// "interesting" module.
+static LogicalResult eraseRedundantBlocksInRegion(ModuleOp module,
+                                                  Region &region,
+                                                  const Tester &test) {
+  std::pair<Tester::Interestingness, size_t> initStatus =
+      test.isInteresting(module);
+
+  // While exploring the reduction tree, we always branch from an interesting
+  // node. Thus the root node must be interesting.
+  if (initStatus.first != Tester::Interestingness::True)
+    return module.emitWarning() << "uninterested module will not be reduced";
+  llvm::SpecificBumpPtrAllocator<ReductionNode> allocator;
+
+  // We set the simplification level to Aggressive to enable block merging.
+  GreedyRewriteConfig config;
+  config.setRegionSimplificationLevel(GreedySimplifyRegionLevel::Aggressive);
+  config.setUseTopDownTraversal(true);
+
+  // Populate canonicalization patterns for cf ops. When all targets of a
+  // 'cf.cond_br' or 'cf.switch' point to the same block, they will be
+  // canonicalized into a 'cf.br'.
+  auto context = region.getContext();
+  RewritePatternSet patterns(context);
+  cf::BranchOp::getCanonicalizationPatterns(patterns, context);
+  cf::CondBranchOp::getCanonicalizationPatterns(patterns, context);
+  cf::SwitchOp::getCanonicalizationPatterns(patterns, context);
+  FrozenRewritePatternSet fPatterns = std::move(patterns);
+
+  ReductionNode *smallestNode = nullptr;
+  mlir::OpBuilder b(context);
+  while (Operation *branchTerminator = getBranchTerminatorInRegion(region)) {
+    size_t numSuccessor = branchTerminator->getNumSuccessors();
+    std::vector<ReductionNode::Range> ranges{
+        {0, std::distance(region.op_begin(), region.op_end())}};
+    // Iterate through each successor of the branching terminator to try
+    // reducing the control flow to a single-path execution.
+    int branchIdx = -1;
+    for (int i = 0, e = numSuccessor; i < e; ++i) {
+      // We allocate memory on the heap because the object will be assigned to
+      // 'smallestNode'.
+      ReductionNode *root = allocator.Allocate();
+      new (root) ReductionNode(nullptr, ranges, allocator);
+      mlir::IRMapping mapper;
+      if (failed(root->initialize(module, region, mapper)))
+        llvm_unreachable("unexpected initialization failure");
+      Operation *tergetTerminator = mapper.lookup(branchTerminator);
+      Block *selectedBlock = tergetTerminator->getSuccessor(i);
+      auto branchOp = cast<BranchOpInterface>(tergetTerminator);
+      ValueRange selectedBlockOperands =
+          branchOp.getSuccessorForwardOperands(selectedBlock);
+      b.setInsertionPointAfter(tergetTerminator);
+      cf::BranchOp::create(b, tergetTerminator->getLoc(), selectedBlock,
+                           selectedBlockOperands);
+      tergetTerminator->erase();
+
+      // Apply canonicalization patterns to collapse the now-redundant branches
+      (void)applyPatternsGreedily(root->getRegion().getParentOp(), fPatterns,
+                                  config);
+      root->update(test.isInteresting(root->getModule()));
+
+      // Track the smallest "interesting" version of the IR found so far.
+      if (root->isInteresting() == Tester::Interestingness::True &&
+          (smallestNode == nullptr ||
+           root->getSize() < smallestNode->getSize())) {
+        smallestNode = root;
+        branchIdx = i;
+      }
+    }
+
+    // If an interesting reduced branch was found, commit the change to the
+    // original region and re-apply patterns for a final cleanup.
+    if (branchIdx != -1) {
+      Block *selectedBlock = branchTerminator->getSuccessor(branchIdx);
+      auto branchOp = cast<BranchOpInterface>(branchTerminator);
+      ValueRange selectedBlockOperands =
+          branchOp.getSuccessorForwardOperands(selectedBlock);
+      b.setInsertionPointAfter(branchTerminator);
+      cf::BranchOp::create(b, branchTerminator->getLoc(), selectedBlock,
+                           selectedBlockOperands);
+      branchTerminator->erase();
+      (void)applyPatternsGreedily(region.getParentOp(), fPatterns, config);
+    }
+  }
+
+  // If no branching terminators were found (skipping the while loop),
+  // there might still be opportunities for linear block merging. We
+  // apply patterns here as a final cleanup to ensure the region is fully
+  // simplified.
+  if (smallestNode == nullptr)
+    (void)applyPatternsGreedily(region.getParentOp(), fPatterns, config);
+  return success();
+}
+
 template <typename IteratorType>
 static LogicalResult findOptimal(ModuleOp module, Region &region,
                                  const FrozenRewritePatternSet &patterns,
@@ -196,6 +308,8 @@ static LogicalResult findOptimal(ModuleOp module, Region &region,
   if (succeeded(eraseAllOpsInRegion(module, region, test)))
     return success();
 
+  (void)eraseRedundantBlocksInRegion(module, region, test);
+
   // In the second phase, we don't apply any patterns so that we only select the
   // range of operations to keep to the module stay interesting.
   if (failed(findOptimal<IteratorType>(module, region, /*patterns=*/{}, test,

diff --git a/mlir/test/mlir-reduce/reduction-tree.mlir b/mlir/test/mlir-reduce/reduction-tree.mlir
index 2aee89741b42b..b053a111e9a16 100644
--- a/mlir/test/mlir-reduce/reduction-tree.mlir
+++ b/mlir/test/mlir-reduce/reduction-tree.mlir
@@ -58,3 +58,68 @@ func.func @simple4(%arg0: i1, %arg1: memref<2xf32>, %arg2: memref<2xf32>) {
 func.func @simple5() {
   return
 }
+
+// -----
+
+// CHECK-LABEL: func @br_reduction
+//  CHECK-SAME:  %[[ARG0:.*]]: i1,
+//  CHECK-SAME:  %[[ARG1:.*]]: memref<2xf32>,
+//  CHECK-SAME:  %[[ARG2:.*]]: memref<2xf32>) {
+func.func @br_reduction(%arg0: i1, %arg1: memref<2xf32>, %arg2: memref<2xf32>) {
+  cf.cond_br %arg0, ^bb1, ^bb2
+^bb1:
+  cf.br ^bb3(%arg1 : memref<2xf32>)
+^bb2:
+  %0 = memref.alloc() : memref<2xf32>
+  cf.br ^bb3(%0 : memref<2xf32>)
+^bb3(%1: memref<2xf32>):
+  "test.op_crash"(%1, %arg2) : (memref<2xf32>, memref<2xf32>) -> ()
+  return
+}
+// CHECK-NEXT: "test.op_crash"(%[[ARG1]], %[[ARG2]])
+
+// -----
+
+// CHECK-LABEL: func @br_reduction_loop
+//  CHECK-SAME:   %[[ARG0:.*]]: i1,
+//  CHECK-SAME:   %[[ARG1:.*]]: memref<2xf32>,
+//  CHECK-SAME:   %[[ARG2:.*]]: memref<2xf32>) {
+func.func @br_reduction_loop(%arg0: i1, %arg1: memref<2xf32>, %arg2: memref<2xf32>) {
+  // select ^bb2
+  cf.cond_br %arg0, ^bb1, ^bb2
+^bb1:
+  cf.br ^bb3(%arg1 : memref<2xf32>)
+^bb2:
+  %0 = memref.alloc() : memref<2xf32>
+  cf.br ^bb3(%0 : memref<2xf32>)
+^bb3(%1: memref<2xf32>):
+  "test.op_crash"(%1, %arg2) : (memref<2xf32>, memref<2xf32>) -> ()
+  // select ^bb4
+  cf.cond_br %arg0, ^bb3(%1: memref<2xf32>), ^bb4
+^bb4:
+  return
+}
+// CHECK-NEXT:  "test.op_crash"(%[[ARG1]], %[[ARG2]])
+
+// -----
+
+// CHECK-LABEL: func @switch_reduction
+//  CHECK-SAME:   %[[ARG0:.*]]: i32,
+//  CHECK-SAME:   %[[ARG1:.*]]: memref<2xf32>,
+//  CHECK-SAME:   %[[ARG2:.*]]: memref<2xf32>) {
+func.func @switch_reduction(%arg0: i32, %arg1: memref<2xf32>, %arg2: memref<2xf32>) {
+  cf.switch %arg0 : i32, [
+    default: ^bb3(%arg1 : memref<2xf32>),
+    0: ^bb1,
+    1: ^bb2
+  ]
+^bb1:
+  cf.br ^bb3(%arg1 : memref<2xf32>)
+^bb2:
+  %0 = memref.alloc() : memref<2xf32>
+  cf.br ^bb3(%0 : memref<2xf32>)
+^bb3(%1: memref<2xf32>):
+  "test.op_crash"(%1, %arg2) : (memref<2xf32>, memref<2xf32>) -> ()
+  return
+}
+// CHECK-NEXT:  "test.op_crash"(%[[ARG1]], %[[ARG2]])


        


More information about the Mlir-commits mailing list