[Mlir-commits] [mlir] [mlir][Interfaces] Add `RegionBranchOpInterface::getSuccessorOperands` helper (PR #173971)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Dec 30 01:55:39 PST 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

Author: Matthias Springer (matthias-springer)

<details>
<summary>Changes</summary>

Add a helper for querying the successor operands for a region branch `src -> dst`. `src` may be the region branch op itself or a terminator, and `dst` may be a region or the region branch op itself (i.e., its results).

This helper allows users to query successor operands for the region branch op and its terminators in a uniform way. This is similar to `getSuccessorRegions(RegionBranchPoint)`, which works for both region branch ops and terminators.


---
Full diff: https://github.com/llvm/llvm-project/pull/173971.diff


3 Files Affected:

- (modified) mlir/include/mlir/Interfaces/ControlFlowInterfaces.td (+9) 
- (modified) mlir/lib/Interfaces/ControlFlowInterfaces.cpp (+10) 
- (modified) mlir/lib/Transforms/RemoveDeadValues.cpp (+19-21) 


``````````diff
diff --git a/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td b/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td
index 94242e3ba39ce..8760c8b8715f9 100644
--- a/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td
+++ b/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td
@@ -350,6 +350,15 @@ def RegionBranchOpInterface : OpInterface<"RegionBranchOpInterface"> {
     /// Return `true` if there is a loop in the region branching graph. Only
     /// reachable regions (starting from the entry regions) are considered.
     bool hasLoop();
+
+    /// Return the successor operands from the source branch point to the
+    /// destination region successor.
+    ///
+    /// If the branch point is the parent op, this function returns entry
+    /// successor operands of this op. Otherwise, it returns successor operands
+    /// of the respective terminator.
+    ::mlir::OperandRange getSuccessorOperands(
+        ::mlir::RegionBranchPoint src, ::mlir::RegionSuccessor dest);
   }];
 }
 
diff --git a/mlir/lib/Interfaces/ControlFlowInterfaces.cpp b/mlir/lib/Interfaces/ControlFlowInterfaces.cpp
index 7420412f09360..51861d7751450 100644
--- a/mlir/lib/Interfaces/ControlFlowInterfaces.cpp
+++ b/mlir/lib/Interfaces/ControlFlowInterfaces.cpp
@@ -479,6 +479,16 @@ bool RegionBranchOpInterface::hasLoop() {
   return false;
 }
 
+OperandRange
+RegionBranchOpInterface::getSuccessorOperands(RegionBranchPoint src,
+                                              RegionSuccessor dest) {
+  if (src.isParent())
+    return getEntrySuccessorOperands(dest);
+  auto terminator = cast<RegionBranchTerminatorOpInterface>(
+      src.getTerminatorPredecessorOrNull());
+  return terminator.getSuccessorOperands(dest);
+}
+
 Region *mlir::getEnclosingRepetitiveRegion(Operation *op) {
   LDBG() << "Finding enclosing repetitive region for operation "
          << op->getName();
diff --git a/mlir/lib/Transforms/RemoveDeadValues.cpp b/mlir/lib/Transforms/RemoveDeadValues.cpp
index 62ce5e0bbb77e..c7b9b49c9c159 100644
--- a/mlir/lib/Transforms/RemoveDeadValues.cpp
+++ b/mlir/lib/Transforms/RemoveDeadValues.cpp
@@ -475,13 +475,10 @@ static void processRegionBranchOp(RegionBranchOpInterface regionBranchOp,
   // Return the operands of `terminator` that are forwarded to `successor` if
   // the former is not null. Else return the operands of `regionBranchOp`
   // forwarded to `successor`.
-  auto getForwardedOpOperands = [&](const RegionSuccessor &successor,
-                                    Operation *terminator = nullptr) {
-    OperandRange operands =
-        terminator ? cast<RegionBranchTerminatorOpInterface>(terminator)
-                         .getSuccessorOperands(successor)
-                   : regionBranchOp.getEntrySuccessorOperands(successor);
-    SmallVector<OpOperand *> opOperands = operandsToOpOperands(operands);
+  auto getForwardedOpOperands = [&](RegionBranchPoint src,
+                                    const RegionSuccessor &successor) {
+    SmallVector<OpOperand *> opOperands = operandsToOpOperands(
+        regionBranchOp.getSuccessorOperands(src, successor));
     return opOperands;
   };
 
@@ -491,7 +488,8 @@ static void processRegionBranchOp(RegionBranchOpInterface regionBranchOp,
     nonForwardedOperands.resize(regionBranchOp->getNumOperands(), true);
     for (const RegionSuccessor &successor :
          getSuccessors(RegionBranchPoint::parent())) {
-      for (OpOperand *opOperand : getForwardedOpOperands(successor))
+      for (OpOperand *opOperand :
+           getForwardedOpOperands(RegionBranchPoint::parent(), successor))
         nonForwardedOperands.reset(opOperand->getOperandNumber());
     }
   };
@@ -504,14 +502,13 @@ static void processRegionBranchOp(RegionBranchOpInterface regionBranchOp,
           if (region.empty())
             continue;
           // TODO: this isn't correct in face of multiple terminators.
-          Operation *terminator = region.front().getTerminator();
+          auto terminator = cast<RegionBranchTerminatorOpInterface>(
+              region.front().getTerminator());
           nonForwardedRets[terminator] =
               BitVector(terminator->getNumOperands(), true);
-          for (const RegionSuccessor &successor :
-               getSuccessors(RegionBranchPoint(
-                   cast<RegionBranchTerminatorOpInterface>(terminator)))) {
-            for (OpOperand *opOperand :
-                 getForwardedOpOperands(successor, terminator))
+          for (const RegionSuccessor &successor : getSuccessors(terminator)) {
+            for (OpOperand *opOperand : getForwardedOpOperands(
+                     RegionBranchPoint(terminator), successor))
               nonForwardedRets[terminator].reset(opOperand->getOperandNumber());
           }
         }
@@ -535,7 +532,7 @@ static void processRegionBranchOp(RegionBranchOpInterface regionBranchOp,
         for (const RegionSuccessor &successor : getSuccessors(point)) {
           Region *successorRegion = successor.getSuccessor();
           for (auto [opOperand, input] :
-               llvm::zip(getForwardedOpOperands(successor, terminator),
+               llvm::zip(getForwardedOpOperands(point, successor),
                          successor.getSuccessorInputs())) {
             size_t operandNum = opOperand->getOperandNumber();
             bool updateBasedOn =
@@ -563,7 +560,8 @@ static void processRegionBranchOp(RegionBranchOpInterface regionBranchOp,
              getSuccessors(RegionBranchPoint::parent())) {
           Region *successorRegion = successor.getSuccessor();
           for (auto [opOperand, input] :
-               llvm::zip(getForwardedOpOperands(successor),
+               llvm::zip(getForwardedOpOperands(RegionBranchPoint::parent(),
+                                                successor),
                          successor.getSuccessorInputs())) {
             bool recomputeBasedOn =
                 operandsToKeep[opOperand->getOperandNumber()];
@@ -593,13 +591,13 @@ static void processRegionBranchOp(RegionBranchOpInterface regionBranchOp,
         for (Region &region : regionBranchOp->getRegions()) {
           if (region.empty())
             continue;
-          Operation *terminator = region.front().getTerminator();
-          for (const RegionSuccessor &successor :
-               getSuccessors(RegionBranchPoint(
-                   cast<RegionBranchTerminatorOpInterface>(terminator)))) {
+          auto terminator = cast<RegionBranchTerminatorOpInterface>(
+              region.front().getTerminator());
+          for (const RegionSuccessor &successor : getSuccessors(terminator)) {
             Region *successorRegion = successor.getSuccessor();
             for (auto [opOperand, input] :
-                 llvm::zip(getForwardedOpOperands(successor, terminator),
+                 llvm::zip(getForwardedOpOperands(RegionBranchPoint(terminator),
+                                                  successor),
                            successor.getSuccessorInputs())) {
               bool recomputeBasedOn =
                   terminatorOperandsToKeep[region.back().getTerminator()]

``````````

</details>


https://github.com/llvm/llvm-project/pull/173971


More information about the Mlir-commits mailing list