[Mlir-commits] [mlir] [mlir][reducer] Add eraseRedundantBlocksInRegion and getSuccessorForwardOperands API to BranchOpInterface (PR #187864)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Sat Mar 21 06:16:07 PDT 2026


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir-cf

Author: lonely eagle (linuxlonelyeagle)

<details>
<summary>Changes</summary>

To simplify the output of the reduction-tree pass, this PR introduces `eraseRedundantBlocksInRegion`. For regions containing multiple execution paths, it selects the shortest 'interesting' path. Additionally, this PR adds the `getSuccessorForwardOperands` API to `BranchOpInterface`. This lets us extract the operands forwarded along a specific path chosen from multiple alternatives, so that a `cf.br` operation can be created for the redirected jump.

---
Full diff: https://github.com/llvm/llvm-project/pull/187864.diff


7 Files Affected:

- (modified) mlir/include/mlir/Dialect/ControlFlow/IR/ControlFlowOps.td (+6-4) 
- (modified) mlir/include/mlir/Interfaces/ControlFlowInterfaces.td (+9) 
- (modified) mlir/include/mlir/Reducer/ReductionNode.h (+3) 
- (modified) mlir/lib/Dialect/ControlFlow/IR/ControlFlowOps.cpp (+24) 
- (modified) mlir/lib/Reducer/ReductionNode.cpp (+10) 
- (modified) mlir/lib/Reducer/ReductionTreePass.cpp (+115) 
- (modified) mlir/test/mlir-reduce/reduction-tree.mlir (+66) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/ControlFlow/IR/ControlFlowOps.td b/mlir/include/mlir/Dialect/ControlFlow/IR/ControlFlowOps.td
index a441fd82546e3..ddea3a7eae590 100644
--- a/mlir/include/mlir/Dialect/ControlFlow/IR/ControlFlowOps.td
+++ b/mlir/include/mlir/Dialect/ControlFlow/IR/ControlFlowOps.td
@@ -65,7 +65,8 @@ def AssertOp : CF_Op<"assert",
 //===----------------------------------------------------------------------===//
 
 def BranchOp : CF_Op<"br", [
-    DeclareOpInterfaceMethods<BranchOpInterface, ["getSuccessorForOperands"]>,
+    DeclareOpInterfaceMethods<BranchOpInterface,
+    ["getSuccessorForOperands", "getSuccessorForwardOperands"]>,
     Pure, Terminator
   ]> {
   let summary = "Branch operation";
@@ -114,8 +115,8 @@ def BranchOp : CF_Op<"br", [
 
 def CondBranchOp
     : CF_Op<"cond_br", [AttrSizedOperandSegments,
-                        DeclareOpInterfaceMethods<
-                            BranchOpInterface, ["getSuccessorForOperands"]>,
+                        DeclareOpInterfaceMethods<BranchOpInterface,
+                        ["getSuccessorForOperands", "getSuccessorForwardOperands"]>,
                         WeightedBranchOpInterface, Pure, Terminator]> {
   let summary = "Conditional branch operation";
   let description = [{
@@ -241,7 +242,8 @@ def CondBranchOp
 
 def SwitchOp : CF_Op<"switch",
     [AttrSizedOperandSegments,
-     DeclareOpInterfaceMethods<BranchOpInterface, ["getSuccessorForOperands"]>,
+     DeclareOpInterfaceMethods<BranchOpInterface,
+     ["getSuccessorForOperands", "getSuccessorForwardOperands"]>,
      Pure, Terminator]> {
   let summary = "Switch operation";
   let description = [{
diff --git a/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td b/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td
index 06fa724e05fab..d32be0c63acc7 100644
--- a/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td
+++ b/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td
@@ -98,6 +98,15 @@ def BranchOpInterface : OpInterface<"BranchOpInterface"> {
       (ins "::mlir::Type":$lhs, "::mlir::Type":$rhs), [{}],
        [{ return lhs == rhs; }]
     >,
+    InterfaceMethod<[{
+      Returns the operands of this operation that are forwarded to the block
+      arguments of the specified successor. If the successor is not valid for
+      this operation, or no operands are forwarded, an empty ValueRange is
+      returned.
+      }],
+      "ValueRange", "getSuccessorForwardOperands",
+      (ins "Block *":$successor), [{}],[{ return {};}]
+    >,
   ];
 
   let verify = [{
diff --git a/mlir/include/mlir/Reducer/ReductionNode.h b/mlir/include/mlir/Reducer/ReductionNode.h
index 6ca4e13d159ac..125a7c6f6f5e7 100644
--- a/mlir/include/mlir/Reducer/ReductionNode.h
+++ b/mlir/include/mlir/Reducer/ReductionNode.h
@@ -90,6 +90,9 @@ class ReductionNode {
   /// corresponding region.
   LogicalResult initialize(ModuleOp parentModule, Region &parentRegion);
 
+  LogicalResult initialize(ModuleOp parentModule, Region &parentRegion,
+                           IRMapping &mapper);
+
 private:
   /// A custom BFS iterator. The difference between
   /// llvm/ADT/BreadthFirstIterator.h is the graph we're exploring is dynamic.
diff --git a/mlir/lib/Dialect/ControlFlow/IR/ControlFlowOps.cpp b/mlir/lib/Dialect/ControlFlow/IR/ControlFlowOps.cpp
index 435c37bc95aac..f6eb0f05911b8 100644
--- a/mlir/lib/Dialect/ControlFlow/IR/ControlFlowOps.cpp
+++ b/mlir/lib/Dialect/ControlFlow/IR/ControlFlowOps.cpp
@@ -296,6 +296,12 @@ Block *BranchOp::getSuccessorForOperands(ArrayRef<Attribute>) {
   return getDest();
 }
 
+ValueRange BranchOp::getSuccessorForwardOperands(Block *successor) {
+  if (successor != getDest())
+    return {};
+  return getDestOperands();
+}
+
 //===----------------------------------------------------------------------===//
 // CondBranchOp
 //===----------------------------------------------------------------------===//
@@ -583,6 +589,14 @@ Block *CondBranchOp::getSuccessorForOperands(ArrayRef<Attribute> operands) {
   return nullptr;
 }
 
+ValueRange CondBranchOp::getSuccessorForwardOperands(Block *successor) {
+  if (successor == getTrueDest())
+    return getTrueOperands();
+  if (successor == getFalseDest())
+    return getFalseOperands();
+  return {};
+}
+
 //===----------------------------------------------------------------------===//
 // SwitchOp
 //===----------------------------------------------------------------------===//
@@ -1034,6 +1048,16 @@ void SwitchOp::getCanonicalizationPatterns(RewritePatternSet &results,
       .add<SimplifyUniformBlockArguments>(context);
 }
 
+ValueRange SwitchOp::getSuccessorForwardOperands(Block *successor) {
+  if (successor == getDefaultDestination())
+    return getDefaultOperands();
+  SuccessorRange cases = getCaseDestinations();
+  auto match = llvm::find(cases, successor);
+  if (match != cases.end())
+    return getCaseOperands(std::distance(cases.begin(), match));
+  return {};
+}
+
 //===----------------------------------------------------------------------===//
 // TableGen'd op method definitions
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Reducer/ReductionNode.cpp b/mlir/lib/Reducer/ReductionNode.cpp
index 11aeaf77b4642..897aae0becf33 100644
--- a/mlir/lib/Reducer/ReductionNode.cpp
+++ b/mlir/lib/Reducer/ReductionNode.cpp
@@ -45,6 +45,16 @@ LogicalResult ReductionNode::initialize(ModuleOp parentModule,
   return success();
 }
 
+LogicalResult ReductionNode::initialize(ModuleOp parentModule,
+                                        Region &targetRegion,
+                                        IRMapping &mapper) {
+  module = cast<ModuleOp>(parentModule->clone(mapper));
+  // The cloned region is the parent of the clone of targetRegion's first block.
+  Block *clonedFirstBlock = mapper.lookup(&*targetRegion.begin());
+  region = clonedFirstBlock->getParent();
+  return success();
+}
+
 /// If we haven't explored any variants from this node, we will create N
 /// variants, N is the length of `ranges` if N > 1. Otherwise, we will split the
 /// max element in `ranges` and create 2 new variants for each call.
diff --git a/mlir/lib/Reducer/ReductionTreePass.cpp b/mlir/lib/Reducer/ReductionTreePass.cpp
index 83497143d9669..8a18c65fdacca 100644
--- a/mlir/lib/Reducer/ReductionTreePass.cpp
+++ b/mlir/lib/Reducer/ReductionTreePass.cpp
@@ -14,7 +14,10 @@
 //
 //===----------------------------------------------------------------------===//
 
+#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
 #include "mlir/IR/DialectInterface.h"
+#include "mlir/IR/IRMapping.h"
+#include "mlir/Interfaces/ControlFlowInterfaces.h"
 #include "mlir/Reducer/Passes.h"
 #include "mlir/Reducer/ReductionNode.h"
 #include "mlir/Reducer/ReductionPatternInterface.h"
@@ -24,6 +27,9 @@
 
 #include "llvm/ADT/ArrayRef.h"
 #include "llvm/Support/Allocator.h"
+#include "llvm/Support/DebugLog.h"
+
+#define DEBUG_TYPE "reduction-tree"
 
 namespace mlir {
 #define GEN_PASS_DEF_REDUCTIONTREEPASS
@@ -184,6 +190,113 @@ static LogicalResult eraseAllOpsInRegion(ModuleOp module, Region &region,
   return failure();
 }
 
+// Scans the blocks of the region in order and returns the terminator of the
+// first one that has more than one successor, or nullptr if there is none.
+static Operation *getBranchTerminatorInRegion(Region &region) {
+  for (Block &candidate : region.getBlocks()) {
+    if (candidate.getNumSuccessors() > 1)
+      return candidate.getTerminator();
+  }
+  return nullptr;
+}
+
+/// Reduces the control flow in `region` by repeatedly rewriting a terminator
+/// with several successors into a `cf.br` to a single one of them. Each
+/// successor is first tried on a clone of `module`; a choice that keeps the
+/// clone "interesting" is then committed to the original region.
+static LogicalResult eraseRedundantBlocksInRegion(ModuleOp module,
+                                                  Region &region,
+                                                  const Tester &test) {
+  std::pair<Tester::Interestingness, size_t> initStatus =
+      test.isInteresting(module);
+
+  // While exploring the reduction tree, we always branch from an interesting
+  // node. Thus the root node must be interesting.
+  if (initStatus.first != Tester::Interestingness::True)
+    return module.emitWarning() << "uninterested module will not be reduced";
+  llvm::SpecificBumpPtrAllocator<ReductionNode> allocator;
+
+  // We set the simplification level to Aggressive to enable block merging.
+  GreedyRewriteConfig config;
+  config.setRegionSimplificationLevel(GreedySimplifyRegionLevel::Aggressive);
+  config.setUseTopDownTraversal(true);
+
+  // Populate canonicalization patterns for cf ops. When all targets of a
+  // 'cf.cond_br' or 'cf.switch' point to the same block, they will be
+  // canonicalized into a 'cf.br'.
+  auto context = region.getContext();
+  RewritePatternSet patterns(context);
+  cf::BranchOp::getCanonicalizationPatterns(patterns, context);
+  cf::CondBranchOp::getCanonicalizationPatterns(patterns, context);
+  cf::SwitchOp::getCanonicalizationPatterns(patterns, context);
+  FrozenRewritePatternSet fPatterns = std::move(patterns);
+
+  ReductionNode *smallestNode = nullptr;
+  mlir::OpBuilder b(context);
+  while (Operation *branchTerminator = getBranchTerminatorInRegion(region)) {
+    size_t numSuccessor = branchTerminator->getNumSuccessors();
+    // Allocate from 'allocator' because 'smallestNode' may keep the address.
+    // NOTE(review): re-constructed in place per successor; verify lifetime.
+    ReductionNode *root = allocator.Allocate();
+    std::vector<ReductionNode::Range> ranges{
+        {0, std::distance(region.op_begin(), region.op_end())}};
+
+    // Iterate through each successor of the branching terminator to try
+    // reducing the control flow to a single-path execution.
+    int branchIdx = -1;
+    for (int i = 0, e = numSuccessor; i < e; ++i) {
+      new (root) ReductionNode(nullptr, ranges, allocator);
+      mlir::IRMapping mapper;
+      if (failed(root->initialize(module, region, mapper)))
+        llvm_unreachable("unexpected initialization failure");
+      Operation *tergetTerminator = mapper.lookup(branchTerminator);
+      Block *selectedBlock = tergetTerminator->getSuccessor(i);
+      auto branchOp = cast<BranchOpInterface>(tergetTerminator);
+      ValueRange selectedBlockOperands =
+          branchOp.getSuccessorForwardOperands(selectedBlock);
+      b.setInsertionPointAfter(tergetTerminator);
+      cf::BranchOp::create(b, tergetTerminator->getLoc(), selectedBlock,
+                           selectedBlockOperands);
+      tergetTerminator->erase();
+
+      // Apply canonicalization patterns to collapse the now-redundant branches.
+      (void)applyPatternsGreedily(root->getRegion().getParentOp(), fPatterns,
+                                  config);
+      root->update(test.isInteresting(root->getModule()));
+
+      // Track the smallest "interesting" version of the IR found so far.
+      if (root->isInteresting() == Tester::Interestingness::True &&
+          (smallestNode == nullptr ||
+           root->getSize() < smallestNode->getSize())) {
+        smallestNode = root;
+        branchIdx = i;
+      }
+    }
+
+    // If an interesting branch was found, commit it to the original region and
+    // re-run patterns. NOTE(review): otherwise nothing changes; endless loop?
+    if (branchIdx != -1) {
+      Block *selectedBlock = branchTerminator->getSuccessor(branchIdx);
+      auto branchOp = cast<BranchOpInterface>(branchTerminator);
+      ValueRange selectedBlockOperands =
+          branchOp.getSuccessorForwardOperands(selectedBlock);
+      b.setInsertionPointAfter(branchTerminator);
+      cf::BranchOp::create(b, branchTerminator->getLoc(), selectedBlock,
+                           selectedBlockOperands);
+      branchTerminator->erase();
+      (void)applyPatternsGreedily(region.getParentOp(), fPatterns, config);
+    }
+  }
+
+  // If no interesting reduction was recorded (for example because the while
+  // loop was skipped for lack of branching terminators), there might still be
+  // opportunities for linear block merging. We apply patterns here as a final
+  // cleanup to ensure the region is fully simplified.
+  if (smallestNode == nullptr)
+    (void)applyPatternsGreedily(region.getParentOp(), fPatterns, config);
+  return success();
+}
+
 template <typename IteratorType>
 static LogicalResult findOptimal(ModuleOp module, Region &region,
                                  const FrozenRewritePatternSet &patterns,
@@ -196,6 +309,8 @@ static LogicalResult findOptimal(ModuleOp module, Region &region,
   if (succeeded(eraseAllOpsInRegion(module, region, test)))
     return success();
 
+  (void)eraseRedundantBlocksInRegion(module, region, test);
+
   // In the second phase, we don't apply any patterns so that we only select the
   // range of operations to keep to the module stay interesting.
   if (failed(findOptimal<IteratorType>(module, region, /*patterns=*/{}, test,
diff --git a/mlir/test/mlir-reduce/reduction-tree.mlir b/mlir/test/mlir-reduce/reduction-tree.mlir
index 2aee89741b42b..b693765fbed53 100644
--- a/mlir/test/mlir-reduce/reduction-tree.mlir
+++ b/mlir/test/mlir-reduce/reduction-tree.mlir
@@ -58,3 +58,69 @@ func.func @simple4(%arg0: i1, %arg1: memref<2xf32>, %arg2: memref<2xf32>) {
 func.func @simple5() {
   return
 }
+
+// -----
+
+// CHECK-LABEL: func @br_reduction
+//  CHECK-SAME:  %[[ARG0:.*]]: i1,
+//  CHECK-SAME:  %[[ARG1:.*]]: memref<2xf32>,
+//  CHECK-SAME:  %[[ARG2:.*]]: memref<2xf32>) {
+func.func @br_reduction(%arg0: i1, %arg1: memref<2xf32>, %arg2: memref<2xf32>) {
+  cf.cond_br %arg0, ^bb1, ^bb2
+^bb1:
+  cf.br ^bb3(%arg1 : memref<2xf32>)
+^bb2:
+  %0 = memref.alloc() : memref<2xf32>
+  cf.br ^bb3(%0 : memref<2xf32>)
+^bb3(%1: memref<2xf32>):
+  "test.op_crash"(%1, %arg2) : (memref<2xf32>, memref<2xf32>) -> ()
+  return
+}
+// CHECK-NEXT: "test.op_crash"(%[[ARG1]], %[[ARG2]])
+
+// -----
+
+// CHECK-LABEL: func @br_reduction_loop
+//  CHECK-SAME:   %[[ARG0:.*]]: i1,
+//  CHECK-SAME:   %[[ARG1:.*]]: memref<2xf32>,
+//  CHECK-SAME:   %[[ARG2:.*]]: memref<2xf32>) {
+func.func @br_reduction_loop(%arg0: i1, %arg1: memref<2xf32>, %arg2: memref<2xf32>) {
+  cf.cond_br %arg0, ^bb1, ^bb2
+^bb1:
+  cf.br ^bb3(%arg1 : memref<2xf32>)
+^bb2:
+  %0 = memref.alloc() : memref<2xf32>
+  cf.br ^bb3(%0 : memref<2xf32>)
+^bb3(%1: memref<2xf32>):
+  "test.op_crash"(%1, %arg2) : (memref<2xf32>, memref<2xf32>) -> ()
+  cf.cond_br %arg0, ^bb3(%1: memref<2xf32>), ^bb4
+^bb4:
+  return
+}
+// CHECK-NEXT:   cf.br ^bb1(%[[ARG1]] : memref<2xf32>)
+// CHECK-NEXT: ^bb1(%[[VAL_0:.*]]: memref<2xf32>):
+// CHECK-NEXT:   "test.op_crash"(%[[VAL_0]], %[[ARG2]])
+// CHECK-NEXT:   cf.br ^bb1(%[[VAL_0]] : memref<2xf32>)
+
+// -----
+
+// CHECK-LABEL: func @switch_reduction
+//  CHECK-SAME:   %[[ARG0:.*]]: i32,
+//  CHECK-SAME:   %[[ARG1:.*]]: memref<2xf32>,
+//  CHECK-SAME:   %[[ARG2:.*]]: memref<2xf32>) {
+func.func @switch_reduction(%arg0: i32, %arg1: memref<2xf32>, %arg2: memref<2xf32>) {
+  cf.switch %arg0 : i32, [
+    default: ^bb3(%arg1 : memref<2xf32>),
+    0: ^bb1,
+    1: ^bb2
+  ]
+^bb1:
+  cf.br ^bb3(%arg1 : memref<2xf32>)
+^bb2:
+  %0 = memref.alloc() : memref<2xf32>
+  cf.br ^bb3(%0 : memref<2xf32>)
+^bb3(%1: memref<2xf32>):
+  "test.op_crash"(%1, %arg2) : (memref<2xf32>, memref<2xf32>) -> ()
+  return
+}
+// CHECK-NEXT:  "test.op_crash"(%[[ARG1]], %[[ARG2]])
\ No newline at end of file

``````````

</details>


https://github.com/llvm/llvm-project/pull/187864


More information about the Mlir-commits mailing list