[Mlir-commits] [mlir] [mlir][Interfaces][NFC] Improve return type of `getTerminatorPredecessorOrNull` (PR #176714)
Matthias Springer
llvmlistbot at llvm.org
Mon Jan 19 01:29:05 PST 2026
https://github.com/matthias-springer created https://github.com/llvm/llvm-project/pull/176714
The terminator is always a `RegionBranchTerminatorOpInterface` (or "null"). There is no other way to construct a `RegionBranchPoint`.
Note: `RegionBranchPoint::predecessor` is still an `Operation *` due to layering constraints. Storing a `RegionBranchTerminatorOpInterface` would require a full definition of `RegionBranchTerminatorOpInterface`, but that interface cannot be defined before `RegionBranchPoint` because its default interface implementations construct a `RegionBranchPoint`.
>From 9332e66492234b3447856c3bed02b7a65e2ea278 Mon Sep 17 00:00:00 2001
From: Matthias Springer <me at m-sp.org>
Date: Mon, 19 Jan 2026 09:24:21 +0000
Subject: [PATCH] [mlir][Interfaces][NFC] Improve return type of
`getTerminatorPredecessorOrNull`
---
.../mlir/Interfaces/ControlFlowInterfaces.h | 64 +++++++++++--------
.../mlir/Interfaces/ControlFlowInterfaces.td | 2 +-
mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp | 2 -
mlir/lib/Interfaces/ControlFlowInterfaces.cpp | 4 +-
4 files changed, 38 insertions(+), 34 deletions(-)
diff --git a/mlir/include/mlir/Interfaces/ControlFlowInterfaces.h b/mlir/include/mlir/Interfaces/ControlFlowInterfaces.h
index a3382e15fb76d..d764089f5ccc8 100644
--- a/mlir/include/mlir/Interfaces/ControlFlowInterfaces.h
+++ b/mlir/include/mlir/Interfaces/ControlFlowInterfaces.h
@@ -252,8 +252,9 @@ class RegionBranchPoint {
bool isParent() const { return predecessor == nullptr; }
/// Returns the terminator if branching from a region.
- /// A null pointer otherwise.
- Operation *getTerminatorPredecessorOrNull() const { return predecessor; }
+ /// A "null" operation otherwise.
+ inline RegionBranchTerminatorOpInterface
+ getTerminatorPredecessorOrNull() const;
/// Returns true if the two branch points are equal.
friend bool operator==(RegionBranchPoint lhs, RegionBranchPoint rhs) {
@@ -269,32 +270,6 @@ class RegionBranchPoint {
Operation *predecessor = nullptr;
};
-inline bool operator!=(RegionBranchPoint lhs, RegionBranchPoint rhs) {
- return !(lhs == rhs);
-}
-
-inline llvm::raw_ostream &operator<<(llvm::raw_ostream &os,
- RegionBranchPoint point) {
- if (point.isParent())
- return os << "<from parent>";
- return os << "<region #"
- << point.getTerminatorPredecessorOrNull()
- ->getParentRegion()
- ->getRegionNumber()
- << ", terminator "
- << OpWithFlags(point.getTerminatorPredecessorOrNull(),
- OpPrintingFlags().skipRegions())
- << ">";
-}
-
-inline llvm::raw_ostream &operator<<(llvm::raw_ostream &os,
- RegionSuccessor successor) {
- if (successor.isParent())
- return os << "<to parent>";
- return os << "<to region #" << successor.getSuccessor()->getRegionNumber()
- << ">";
-}
-
/// This class represents upper and lower bounds on the number of times a region
/// of a `RegionBranchOpInterface` can be invoked. The lower bound is at least
/// zero, but the upper bound may not be known.
@@ -381,6 +356,39 @@ namespace mlir {
inline RegionBranchPoint::RegionBranchPoint(
RegionBranchTerminatorOpInterface predecessor)
: predecessor(predecessor.getOperation()) {}
+
+inline RegionBranchTerminatorOpInterface
+RegionBranchPoint::getTerminatorPredecessorOrNull() const {
+  // A null predecessor (the parent point) maps to a null interface; any other
+  // predecessor must implement the interface (checked by the cast).
+  return llvm::cast_if_present<RegionBranchTerminatorOpInterface>(predecessor);
+}
+
+/// Two branch points are unequal exactly when `operator==` reports false.
+inline bool operator!=(RegionBranchPoint lhs, RegionBranchPoint rhs) {
+  bool equal = lhs == rhs;
+  return !equal;
+}
+
+inline llvm::raw_ostream &operator<<(llvm::raw_ostream &os,
+                                     RegionBranchPoint point) {
+  if (point.isParent())
+    return os << "<from parent>";
+  // Not the parent, so the predecessor is a terminator: fetch it once, then
+  // print its region number followed by the op itself (regions skipped).
+  RegionBranchTerminatorOpInterface terminator =
+      point.getTerminatorPredecessorOrNull();
+  os << "<region #" << terminator->getParentRegion()->getRegionNumber()
+     << ", terminator "
+     << OpWithFlags(terminator, OpPrintingFlags().skipRegions());
+  return os << ">";
+}
+
+/// Prints a region successor as `<to parent>` or `<to region #N>`.
+inline llvm::raw_ostream &operator<<(llvm::raw_ostream &os,
+                                     RegionSuccessor successor) {
+  if (!successor.isParent())
+    return os << "<to region #" << successor.getSuccessor()->getRegionNumber()
+              << ">";
+  return os << "<to parent>";
+}
} // namespace mlir
#endif // MLIR_INTERFACES_CONTROLFLOWINTERFACES_H
diff --git a/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td b/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td
index a485a2c6a610f..627eeb14b9023 100644
--- a/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td
+++ b/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td
@@ -313,7 +313,7 @@ def RegionBranchOpInterface : OpInterface<"RegionBranchOpInterface"> {
predecessorValues.push_back($_op.getEntrySuccessorOperands(successor)[index]);
continue;
}
- auto terminator = cast<RegionBranchTerminatorOpInterface>(predecessor.getTerminatorPredecessorOrNull());
+ auto terminator = predecessor.getTerminatorPredecessorOrNull();
predecessorValues.push_back(terminator.getSuccessorOperands(successor)[index]);
}
}]>,
diff --git a/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp b/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp
index f86bb55df3ac5..6515e42bb2081 100644
--- a/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp
+++ b/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp
@@ -639,8 +639,6 @@ void AbstractSparseBackwardDataFlowAnalysis::
visitRegionSuccessorsFromTerminator(
RegionBranchTerminatorOpInterface terminator,
RegionBranchOpInterface branch) {
- assert(isa<RegionBranchTerminatorOpInterface>(terminator) &&
- "expected a `RegionBranchTerminatorOpInterface` op");
assert(terminator->getParentOp() == branch.getOperation() &&
"expected `branch` to be the parent op of `terminator`");
diff --git a/mlir/lib/Interfaces/ControlFlowInterfaces.cpp b/mlir/lib/Interfaces/ControlFlowInterfaces.cpp
index ebd4b63145f92..ebf78d8bd60ce 100644
--- a/mlir/lib/Interfaces/ControlFlowInterfaces.cpp
+++ b/mlir/lib/Interfaces/ControlFlowInterfaces.cpp
@@ -439,9 +439,7 @@ RegionBranchOpInterface::getSuccessorOperands(RegionBranchPoint src,
RegionSuccessor dest) {
if (src.isParent())
return getEntrySuccessorOperands(dest);
- auto terminator = cast<RegionBranchTerminatorOpInterface>(
- src.getTerminatorPredecessorOrNull());
- return terminator.getSuccessorOperands(dest);
+ return src.getTerminatorPredecessorOrNull().getSuccessorOperands(dest);
}
static MutableArrayRef<OpOperand> operandsToOpOperands(OperandRange &operands) {
More information about the Mlir-commits
mailing list