[Mlir-commits] [mlir] [mlir][Interfaces] Add `RegionBranchOpInterface::getSuccessorOperands` helper (PR #173971)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Tue Dec 30 01:55:39 PST 2025
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-core
Author: Matthias Springer (matthias-springer)
<details>
<summary>Changes</summary>
Add a helper for querying the successor operands of a region branch `src -> dst`. The source `src` (a `RegionBranchPoint`) may be the region branch op itself or one of its terminators; the destination `dst` is a `RegionSuccessor`.
This helper lets users query successor operands for the region branch op and its terminators in a uniform way. It mirrors `getSuccessorRegions(RegionBranchPoint)`, which likewise accepts either the region branch op or a terminator as the branch point.
---
Full diff: https://github.com/llvm/llvm-project/pull/173971.diff
3 Files Affected:
- (modified) mlir/include/mlir/Interfaces/ControlFlowInterfaces.td (+9)
- (modified) mlir/lib/Interfaces/ControlFlowInterfaces.cpp (+10)
- (modified) mlir/lib/Transforms/RemoveDeadValues.cpp (+19-21)
``````````diff
diff --git a/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td b/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td
index 94242e3ba39ce..8760c8b8715f9 100644
--- a/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td
+++ b/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td
@@ -350,6 +350,15 @@ def RegionBranchOpInterface : OpInterface<"RegionBranchOpInterface"> {
/// Return `true` if there is a loop in the region branching graph. Only
/// reachable regions (starting from the entry regions) are considered.
bool hasLoop();
+
+ /// Return the successor operands from the source branch point to the
+ /// destination region successor.
+ ///
+ /// If the branch point is the parent op, this function returns entry
+ /// successor operands of this op. Otherwise, it returns successor operands
+ /// of the respective terminator.
+ ::mlir::OperandRange getSuccessorOperands(
+ ::mlir::RegionBranchPoint src, ::mlir::RegionSuccessor dest);
}];
}
diff --git a/mlir/lib/Interfaces/ControlFlowInterfaces.cpp b/mlir/lib/Interfaces/ControlFlowInterfaces.cpp
index 7420412f09360..51861d7751450 100644
--- a/mlir/lib/Interfaces/ControlFlowInterfaces.cpp
+++ b/mlir/lib/Interfaces/ControlFlowInterfaces.cpp
@@ -479,6 +479,16 @@ bool RegionBranchOpInterface::hasLoop() {
return false;
}
+OperandRange
+RegionBranchOpInterface::getSuccessorOperands(RegionBranchPoint src,
+ RegionSuccessor dest) {
+ if (src.isParent())
+ return getEntrySuccessorOperands(dest);
+ auto terminator = cast<RegionBranchTerminatorOpInterface>(
+ src.getTerminatorPredecessorOrNull());
+ return terminator.getSuccessorOperands(dest);
+}
+
Region *mlir::getEnclosingRepetitiveRegion(Operation *op) {
LDBG() << "Finding enclosing repetitive region for operation "
<< op->getName();
diff --git a/mlir/lib/Transforms/RemoveDeadValues.cpp b/mlir/lib/Transforms/RemoveDeadValues.cpp
index 62ce5e0bbb77e..c7b9b49c9c159 100644
--- a/mlir/lib/Transforms/RemoveDeadValues.cpp
+++ b/mlir/lib/Transforms/RemoveDeadValues.cpp
@@ -475,13 +475,10 @@ static void processRegionBranchOp(RegionBranchOpInterface regionBranchOp,
// Return the operands of `terminator` that are forwarded to `successor` if
// the former is not null. Else return the operands of `regionBranchOp`
// forwarded to `successor`.
- auto getForwardedOpOperands = [&](const RegionSuccessor &successor,
- Operation *terminator = nullptr) {
- OperandRange operands =
- terminator ? cast<RegionBranchTerminatorOpInterface>(terminator)
- .getSuccessorOperands(successor)
- : regionBranchOp.getEntrySuccessorOperands(successor);
- SmallVector<OpOperand *> opOperands = operandsToOpOperands(operands);
+ auto getForwardedOpOperands = [&](RegionBranchPoint src,
+ const RegionSuccessor &successor) {
+ SmallVector<OpOperand *> opOperands = operandsToOpOperands(
+ regionBranchOp.getSuccessorOperands(src, successor));
return opOperands;
};
@@ -491,7 +488,8 @@ static void processRegionBranchOp(RegionBranchOpInterface regionBranchOp,
nonForwardedOperands.resize(regionBranchOp->getNumOperands(), true);
for (const RegionSuccessor &successor :
getSuccessors(RegionBranchPoint::parent())) {
- for (OpOperand *opOperand : getForwardedOpOperands(successor))
+ for (OpOperand *opOperand :
+ getForwardedOpOperands(RegionBranchPoint::parent(), successor))
nonForwardedOperands.reset(opOperand->getOperandNumber());
}
};
@@ -504,14 +502,13 @@ static void processRegionBranchOp(RegionBranchOpInterface regionBranchOp,
if (region.empty())
continue;
// TODO: this isn't correct in face of multiple terminators.
- Operation *terminator = region.front().getTerminator();
+ auto terminator = cast<RegionBranchTerminatorOpInterface>(
+ region.front().getTerminator());
nonForwardedRets[terminator] =
BitVector(terminator->getNumOperands(), true);
- for (const RegionSuccessor &successor :
- getSuccessors(RegionBranchPoint(
- cast<RegionBranchTerminatorOpInterface>(terminator)))) {
- for (OpOperand *opOperand :
- getForwardedOpOperands(successor, terminator))
+ for (const RegionSuccessor &successor : getSuccessors(terminator)) {
+ for (OpOperand *opOperand : getForwardedOpOperands(
+ RegionBranchPoint(terminator), successor))
nonForwardedRets[terminator].reset(opOperand->getOperandNumber());
}
}
@@ -535,7 +532,7 @@ static void processRegionBranchOp(RegionBranchOpInterface regionBranchOp,
for (const RegionSuccessor &successor : getSuccessors(point)) {
Region *successorRegion = successor.getSuccessor();
for (auto [opOperand, input] :
- llvm::zip(getForwardedOpOperands(successor, terminator),
+ llvm::zip(getForwardedOpOperands(point, successor),
successor.getSuccessorInputs())) {
size_t operandNum = opOperand->getOperandNumber();
bool updateBasedOn =
@@ -563,7 +560,8 @@ static void processRegionBranchOp(RegionBranchOpInterface regionBranchOp,
getSuccessors(RegionBranchPoint::parent())) {
Region *successorRegion = successor.getSuccessor();
for (auto [opOperand, input] :
- llvm::zip(getForwardedOpOperands(successor),
+ llvm::zip(getForwardedOpOperands(RegionBranchPoint::parent(),
+ successor),
successor.getSuccessorInputs())) {
bool recomputeBasedOn =
operandsToKeep[opOperand->getOperandNumber()];
@@ -593,13 +591,13 @@ static void processRegionBranchOp(RegionBranchOpInterface regionBranchOp,
for (Region &region : regionBranchOp->getRegions()) {
if (region.empty())
continue;
- Operation *terminator = region.front().getTerminator();
- for (const RegionSuccessor &successor :
- getSuccessors(RegionBranchPoint(
- cast<RegionBranchTerminatorOpInterface>(terminator)))) {
+ auto terminator = cast<RegionBranchTerminatorOpInterface>(
+ region.front().getTerminator());
+ for (const RegionSuccessor &successor : getSuccessors(terminator)) {
Region *successorRegion = successor.getSuccessor();
for (auto [opOperand, input] :
- llvm::zip(getForwardedOpOperands(successor, terminator),
+ llvm::zip(getForwardedOpOperands(RegionBranchPoint(terminator),
+ successor),
successor.getSuccessorInputs())) {
bool recomputeBasedOn =
terminatorOperandsToKeep[region.back().getTerminator()]
``````````
</details>
https://github.com/llvm/llvm-project/pull/173971
More information about the Mlir-commits
mailing list