[Mlir-commits] [mlir] 024f562 - [mlir] Use a type for representing branch points in `RegionBranchOpInterface`
Markus Böck
llvmlistbot at llvm.org
Tue Aug 29 11:13:48 PDT 2023
Author: Markus Böck
Date: 2023-08-29T20:02:23+02:00
New Revision: 024f562da67180b7be1663048c960b26c2cc16f8
URL: https://github.com/llvm/llvm-project/commit/024f562da67180b7be1663048c960b26c2cc16f8
DIFF: https://github.com/llvm/llvm-project/commit/024f562da67180b7be1663048c960b26c2cc16f8.diff
LOG: [mlir] Use a type for representing branch points in `RegionBranchOpInterface`
The current implementation is not very ergonomic or descriptive: It uses `std::optional<unsigned>` where `std::nullopt` represents the parent op and `unsigned` is the region number.
This doesn't give us any useful methods specific to region control flow, and it makes the code fragile to changes because it relies on region numbers.
This patch introduces a new type called `RegionBranchPoint`, replacing all uses of `std::optional<unsigned>` in the interface. It can be implicitly constructed from a region or a `RegionSuccessor`, can be compared with a region to check whether the branch point is branching from that region, adds `isParent` to check whether we are coming from the parent op, and adds `RegionBranchPoint::parent` as a descriptive way to indicate branching from the parent.
Differential Revision: https://reviews.llvm.org/D159116
Added:
Modified:
mlir/include/mlir/Analysis/DataFlow/DenseAnalysis.h
mlir/include/mlir/Analysis/DataFlow/SparseAnalysis.h
mlir/include/mlir/Interfaces/ControlFlowInterfaces.h
mlir/include/mlir/Interfaces/ControlFlowInterfaces.td
mlir/lib/Analysis/AliasAnalysis/LocalAliasAnalysis.cpp
mlir/lib/Analysis/DataFlow/DenseAnalysis.cpp
mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp
mlir/lib/Dialect/Affine/IR/AffineOps.cpp
mlir/lib/Dialect/Async/IR/Async.cpp
mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocation.cpp
mlir/lib/Dialect/Bufferization/Transforms/BufferUtils.cpp
mlir/lib/Dialect/Bufferization/Transforms/BufferViewFlowAnalysis.cpp
mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
mlir/lib/Dialect/SCF/IR/SCF.cpp
mlir/lib/Dialect/Shape/IR/Shape.cpp
mlir/lib/Dialect/Transform/IR/TransformOps.cpp
mlir/lib/Dialect/Vector/IR/VectorOps.cpp
mlir/lib/Interfaces/ControlFlowInterfaces.cpp
mlir/lib/Transforms/RemoveDeadValues.cpp
mlir/test/lib/Analysis/DataFlow/TestDenseBackwardDataFlowAnalysis.cpp
mlir/test/lib/Dialect/Test/TestDialect.cpp
mlir/unittests/Interfaces/ControlFlowInterfacesTest.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Analysis/DataFlow/DenseAnalysis.h b/mlir/include/mlir/Analysis/DataFlow/DenseAnalysis.h
index a3a558f7705074..6a1335bab8bf6e 100644
--- a/mlir/include/mlir/Analysis/DataFlow/DenseAnalysis.h
+++ b/mlir/include/mlir/Analysis/DataFlow/DenseAnalysis.h
@@ -353,8 +353,8 @@ class AbstractDenseBackwardDataFlowAnalysis : public DataFlowAnalysis {
/// any effect on the lattice that isn't already expressed by the interface
/// itself.
virtual void visitRegionBranchControlFlowTransfer(
- RegionBranchOpInterface branch, std::optional<unsigned> regionFrom,
- std::optional<unsigned> regionTo, const AbstractDenseLattice &after,
+ RegionBranchOpInterface branch, RegionBranchPoint regionFrom,
+ RegionBranchPoint regionTo, const AbstractDenseLattice &after,
AbstractDenseLattice *before) {
meet(before, after);
}
@@ -382,7 +382,7 @@ class AbstractDenseBackwardDataFlowAnalysis : public DataFlowAnalysis {
/// of the branch operation itself.
void visitRegionBranchOperation(ProgramPoint point,
RegionBranchOpInterface branch,
- std::optional<unsigned> regionNo,
+ RegionBranchPoint branchPoint,
AbstractDenseLattice *before);
/// Visit an operation for which the data flow is described by the
@@ -472,9 +472,8 @@ class DenseBackwardDataFlowAnalysis
/// nullptr`. The behavior can be further refined for specific pairs of "from"
/// and "to" regions.
virtual void visitRegionBranchControlFlowTransfer(
- RegionBranchOpInterface branch, std::optional<unsigned> regionFrom,
- std::optional<unsigned> regionTo, const LatticeT &after,
- LatticeT *before) {
+ RegionBranchOpInterface branch, RegionBranchPoint regionFrom,
+ RegionBranchPoint regionTo, const LatticeT &after, LatticeT *before) {
AbstractDenseBackwardDataFlowAnalysis::visitRegionBranchControlFlowTransfer(
branch, regionFrom, regionTo, after, before);
}
@@ -508,8 +507,8 @@ class DenseBackwardDataFlowAnalysis
static_cast<LatticeT *>(before));
}
void visitRegionBranchControlFlowTransfer(
- RegionBranchOpInterface branch, std::optional<unsigned> regionForm,
- std::optional<unsigned> regionTo, const AbstractDenseLattice &after,
+ RegionBranchOpInterface branch, RegionBranchPoint regionForm,
+ RegionBranchPoint regionTo, const AbstractDenseLattice &after,
AbstractDenseLattice *before) final {
visitRegionBranchControlFlowTransfer(branch, regionForm, regionTo,
static_cast<const LatticeT &>(after),
diff --git a/mlir/include/mlir/Analysis/DataFlow/SparseAnalysis.h b/mlir/include/mlir/Analysis/DataFlow/SparseAnalysis.h
index 13dacff3aa0422..5a9a36159b56c5 100644
--- a/mlir/include/mlir/Analysis/DataFlow/SparseAnalysis.h
+++ b/mlir/include/mlir/Analysis/DataFlow/SparseAnalysis.h
@@ -243,7 +243,7 @@ class AbstractSparseForwardDataFlowAnalysis : public DataFlowAnalysis {
/// regions or the parent operation itself, and set either the argument or
/// parent result lattices.
void visitRegionSuccessors(ProgramPoint point, RegionBranchOpInterface branch,
- std::optional<unsigned> successorIndex,
+ RegionBranchPoint successor,
ArrayRef<AbstractSparseLattice *> lattices);
};
diff --git a/mlir/include/mlir/Interfaces/ControlFlowInterfaces.h b/mlir/include/mlir/Interfaces/ControlFlowInterfaces.h
index bd81da41ed43cd..006aedced839f9 100644
--- a/mlir/include/mlir/Interfaces/ControlFlowInterfaces.h
+++ b/mlir/include/mlir/Interfaces/ControlFlowInterfaces.h
@@ -190,6 +190,68 @@ class RegionSuccessor {
ValueRange inputs;
};
+/// This class represents a point being branched from in the methods of the
+/// `RegionBranchOpInterface`.
+/// One can branch from one of two kinds of places:
+/// * The parent operation (aka the `RegionBranchOpInterface` implementation)
+/// * A region within the parent operation.
+class RegionBranchPoint {
+public:
+ /// Returns an instance of `RegionBranchPoint` representing the parent
+ /// operation.
+ static constexpr RegionBranchPoint parent() { return RegionBranchPoint(); }
+
+ /// Creates a `RegionBranchPoint` that branches from the given region.
+ /// The pointer must not be null.
+ RegionBranchPoint(Region *region) : maybeRegion(region) {
+ assert(region && "Region must not be null");
+ }
+
+ RegionBranchPoint(Region &region) : RegionBranchPoint(&region) {}
+
+ /// Explicitly stops users from constructing with `nullptr`.
+ RegionBranchPoint(std::nullptr_t) = delete;
+
+ /// Constructs a `RegionBranchPoint` from the target of a
+ /// `RegionSuccessor` instance.
+ RegionBranchPoint(RegionSuccessor successor) {
+ if (successor.isParent())
+ maybeRegion = nullptr;
+ else
+ maybeRegion = successor.getSuccessor();
+ }
+
+ /// Assigns a region being branched from.
+ RegionBranchPoint &operator=(Region &region) {
+ maybeRegion = &region;
+ return *this;
+ }
+
+ /// Returns true if branching from the parent op.
+ bool isParent() const { return maybeRegion == nullptr; }
+
+ /// Returns the region if branching from a region.
+ /// A null pointer otherwise.
+ Region *getRegionOrNull() const { return maybeRegion; }
+
+ /// Returns true if the two branch points are equal.
+ friend bool operator==(RegionBranchPoint lhs, RegionBranchPoint rhs) {
+ return lhs.maybeRegion == rhs.maybeRegion;
+ }
+
+private:
+ // Private constructor to encourage the use of `RegionBranchPoint::parent`.
+ constexpr RegionBranchPoint() : maybeRegion(nullptr) {}
+
+ /// Internal encoding. Uses nullptr for representing branching from the parent
+ /// op and the region being branched from otherwise.
+ Region *maybeRegion;
+};
+
+inline bool operator!=(RegionBranchPoint lhs, RegionBranchPoint rhs) {
+ return !(lhs == rhs);
+}
+
/// This class represents upper and lower bounds on the number of times a region
/// of a `RegionBranchOpInterface` can be invoked. The lower bound is at least
/// zero, but the upper bound may not be known.
diff --git a/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td b/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td
index 132bd6d53d923a..e52636a5ac8fcc 100644
--- a/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td
+++ b/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td
@@ -133,14 +133,14 @@ def RegionBranchOpInterface : OpInterface<"RegionBranchOpInterface"> {
let methods = [
InterfaceMethod<[{
Returns the operands of this operation used as the entry arguments when
- entering the region at `index`, which was specified as a successor of
+ branching from `point`, which was specified as a successor of
this operation by `getEntrySuccessorRegions`, or the operands forwarded
to the operation's results when it branches back to itself. These operands
should correspond 1-1 with the successor inputs specified in
`getEntrySuccessorRegions`.
}],
"::mlir::OperandRange", "getEntrySuccessorOperands",
- (ins "::std::optional<unsigned>":$index), [{}],
+ (ins "::mlir::RegionBranchPoint":$point), [{}],
/*defaultImplementation=*/[{
auto operandEnd = this->getOperation()->operand_end();
return ::mlir::OperandRange(operandEnd, operandEnd);
@@ -162,22 +162,20 @@ def RegionBranchOpInterface : OpInterface<"RegionBranchOpInterface"> {
(ins "::llvm::ArrayRef<::mlir::Attribute>":$operands,
"::llvm::SmallVectorImpl<::mlir::RegionSuccessor> &":$regions),
[{}], [{
- $_op.getSuccessorRegions(std::nullopt, regions);
+ $_op.getSuccessorRegions(mlir::RegionBranchPoint::parent(), regions);
}]
>,
InterfaceMethod<[{
- Returns the viable successors of a region at `index`, or the possible
- successors when branching from the parent op if `index` is None. These
- are the regions that may be selected during the flow of control. The
- parent operation, i.e. a null `index`, may specify itself as successor,
- which indicates that the control flow may not enter any region at all.
- This method allows for describing which regions may be executed when
- entering an operation, and which regions are executed after having
- executed another region of the parent op. The successor region must be
- non-empty.
+ Returns the viable successors of `point`. These are the regions that may
+ be selected during the flow of control. The parent operation may
+ specify itself as successor, which indicates that the control flow may
+ not enter any region at all. This method allows for describing which
+ regions may be executed when entering an operation, and which regions
+ are executed after having executed another region of the parent op. The
+ successor region must be non-empty.
}],
"void", "getSuccessorRegions",
- (ins "::std::optional<unsigned>":$index,
+ (ins "::mlir::RegionBranchPoint":$point,
"::llvm::SmallVectorImpl<::mlir::RegionSuccessor> &":$regions)
>,
InterfaceMethod<[{
@@ -245,12 +243,10 @@ def RegionBranchTerminatorOpInterface :
let methods = [
InterfaceMethod<[{
Returns a mutable range of operands that are semantically "returned" by
- passing them to the region successor given by `index`. If `index` is
- None, this function returns the operands that are passed as a result to
- the parent operation.
+ passing them to the region successor given by `point`.
}],
"::mlir::MutableOperandRange", "getMutableSuccessorOperands",
- (ins "::std::optional<unsigned>":$index)
+ (ins "::mlir::RegionBranchPoint":$point)
>,
InterfaceMethod<[{
Returns the viable region successors that are branched to after this
@@ -269,8 +265,7 @@ def RegionBranchTerminatorOpInterface :
[{
::mlir::Operation *op = $_op;
::llvm::cast<::mlir::RegionBranchOpInterface>(op->getParentOp())
- .getSuccessorRegions(op->getParentRegion()->getRegionNumber(),
- regions);
+ .getSuccessorRegions(op->getParentRegion(), regions);
}]
>,
];
@@ -290,8 +285,8 @@ def RegionBranchTerminatorOpInterface :
// them to the region successor given by `index`. If `index` is None, this
// function returns the operands that are passed as a result to the parent
// operation.
- ::mlir::OperandRange getSuccessorOperands(std::optional<unsigned> index) {
- return getMutableSuccessorOperands(index);
+ ::mlir::OperandRange getSuccessorOperands(::mlir::RegionBranchPoint point) {
+ return getMutableSuccessorOperands(point);
}
}];
}
@@ -309,7 +304,7 @@ def ReturnLike : TraitList<[
/*extraOpDeclaration=*/"",
/*extraOpDefinition=*/[{
::mlir::MutableOperandRange $cppClass::getMutableSuccessorOperands(
- ::std::optional<unsigned> index) {
+ ::mlir::RegionBranchPoint point) {
return ::mlir::MutableOperandRange(*this);
}
}]
diff --git a/mlir/lib/Analysis/AliasAnalysis/LocalAliasAnalysis.cpp b/mlir/lib/Analysis/AliasAnalysis/LocalAliasAnalysis.cpp
index 970e68bc258649..ae2ba90412137c 100644
--- a/mlir/lib/Analysis/AliasAnalysis/LocalAliasAnalysis.cpp
+++ b/mlir/lib/Analysis/AliasAnalysis/LocalAliasAnalysis.cpp
@@ -45,9 +45,9 @@ static void collectUnderlyingAddressValues(RegionBranchOpInterface branch,
// this region predecessor that correspond to the input values of `region`. If
// an index could not be found, std::nullopt is returned instead.
auto getOperandIndexIfPred =
- [&](std::optional<unsigned> predIndex) -> std::optional<unsigned> {
+ [&](RegionBranchPoint pred) -> std::optional<unsigned> {
SmallVector<RegionSuccessor, 2> successors;
- branch.getSuccessorRegions(predIndex, successors);
+ branch.getSuccessorRegions(pred, successors);
for (RegionSuccessor &successor : successors) {
if (successor.getSuccessor() != region)
continue;
@@ -75,28 +75,27 @@ static void collectUnderlyingAddressValues(RegionBranchOpInterface branch,
};
// Check branches from the parent operation.
- std::optional<unsigned> regionIndex;
- if (region) {
- // Determine the actual region number from the passed region.
- regionIndex = region->getRegionNumber();
- }
+ auto branchPoint = RegionBranchPoint::parent();
+ if (region)
+ branchPoint = region;
+
if (std::optional<unsigned> operandIndex =
- getOperandIndexIfPred(/*predIndex=*/std::nullopt)) {
+ getOperandIndexIfPred(/*predIndex=*/RegionBranchPoint::parent())) {
collectUnderlyingAddressValues(
- branch.getEntrySuccessorOperands(regionIndex)[*operandIndex], maxDepth,
+ branch.getEntrySuccessorOperands(branchPoint)[*operandIndex], maxDepth,
visited, output);
}
// Check branches from each child region.
Operation *op = branch.getOperation();
- for (int i = 0, e = op->getNumRegions(); i != e; ++i) {
- if (std::optional<unsigned> operandIndex = getOperandIndexIfPred(i)) {
- for (Block &block : op->getRegion(i)) {
+ for (Region &region : op->getRegions()) {
+ if (std::optional<unsigned> operandIndex = getOperandIndexIfPred(region)) {
+ for (Block &block : region) {
// Try to determine possible region-branch successor operands for the
// current region.
if (auto term = dyn_cast<RegionBranchTerminatorOpInterface>(
block.getTerminator())) {
collectUnderlyingAddressValues(
- term.getSuccessorOperands(regionIndex)[*operandIndex], maxDepth,
+ term.getSuccessorOperands(branchPoint)[*operandIndex], maxDepth,
visited, output);
} else if (block.getNumSuccessors()) {
// Otherwise, if this terminator may exit the region we can't make
diff --git a/mlir/lib/Analysis/DataFlow/DenseAnalysis.cpp b/mlir/lib/Analysis/DataFlow/DenseAnalysis.cpp
index c79a360e4c11bb..eab408cd5977c3 100644
--- a/mlir/lib/Analysis/DataFlow/DenseAnalysis.cpp
+++ b/mlir/lib/Analysis/DataFlow/DenseAnalysis.cpp
@@ -312,7 +312,8 @@ void AbstractDenseBackwardDataFlowAnalysis::processOperation(Operation *op) {
// Special cases where control flow may dictate data flow.
if (auto branch = dyn_cast<RegionBranchOpInterface>(op))
- return visitRegionBranchOperation(op, branch, std::nullopt, before);
+ return visitRegionBranchOperation(op, branch, RegionBranchPoint::parent(),
+ before);
if (auto call = dyn_cast<CallOpInterface>(op))
return visitCallOperation(call, before);
@@ -368,8 +369,7 @@ void AbstractDenseBackwardDataFlowAnalysis::visitBlock(Block *block) {
// If this block is exiting from an operation with region-based control
// flow, propagate the lattice back along the control flow edge.
if (auto branch = dyn_cast<RegionBranchOpInterface>(block->getParentOp())) {
- visitRegionBranchOperation(block, branch,
- block->getParent()->getRegionNumber(), before);
+ visitRegionBranchOperation(block, branch, block->getParent(), before);
return;
}
@@ -396,13 +396,13 @@ void AbstractDenseBackwardDataFlowAnalysis::visitBlock(Block *block) {
void AbstractDenseBackwardDataFlowAnalysis::visitRegionBranchOperation(
ProgramPoint point, RegionBranchOpInterface branch,
- std::optional<unsigned> regionNo, AbstractDenseLattice *before) {
+ RegionBranchPoint branchPoint, AbstractDenseLattice *before) {
// The successors of the operation may be either the first operation of the
// entry block of each possible successor region, or the next operation when
// the branch is a successor of itself.
SmallVector<RegionSuccessor> successors;
- branch.getSuccessorRegions(regionNo, successors);
+ branch.getSuccessorRegions(branchPoint, successors);
for (const RegionSuccessor &successor : successors) {
const AbstractDenseLattice *after;
if (successor.isParent() || successor.getSuccessor()->empty()) {
@@ -423,10 +423,8 @@ void AbstractDenseBackwardDataFlowAnalysis::visitRegionBranchOperation(
else
after = getLatticeFor(point, &successorBlock->front());
}
- std::optional<unsigned> successorNo =
- successor.isParent() ? std::optional<unsigned>()
- : successor.getSuccessor()->getRegionNumber();
- visitRegionBranchControlFlowTransfer(branch, regionNo, successorNo, *after,
+
+ visitRegionBranchControlFlowTransfer(branch, branchPoint, successor, *after,
before);
}
}
diff --git a/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp b/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp
index 4708cdb042f126..02a0ce1bb29213 100644
--- a/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp
+++ b/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp
@@ -99,7 +99,7 @@ void AbstractSparseForwardDataFlowAnalysis::visitOperation(Operation *op) {
// The results of a region branch operation are determined by control-flow.
if (auto branch = dyn_cast<RegionBranchOpInterface>(op)) {
return visitRegionSuccessors({branch}, branch,
- /*successorIndex=*/std::nullopt,
+ /*successor=*/RegionBranchPoint::parent(),
resultLattices);
}
@@ -167,8 +167,8 @@ void AbstractSparseForwardDataFlowAnalysis::visitBlock(Block *block) {
// Check if the lattices can be determined from region control flow.
if (auto branch = dyn_cast<RegionBranchOpInterface>(block->getParentOp())) {
- return visitRegionSuccessors(
- block, branch, block->getParent()->getRegionNumber(), argLattices);
+ return visitRegionSuccessors(block, branch, block->getParent(),
+ argLattices);
}
// Otherwise, we can't reason about the data-flow.
@@ -212,8 +212,7 @@ void AbstractSparseForwardDataFlowAnalysis::visitBlock(Block *block) {
void AbstractSparseForwardDataFlowAnalysis::visitRegionSuccessors(
ProgramPoint point, RegionBranchOpInterface branch,
- std::optional<unsigned> successorIndex,
- ArrayRef<AbstractSparseLattice *> lattices) {
+ RegionBranchPoint successor, ArrayRef<AbstractSparseLattice *> lattices) {
const auto *predecessors = getOrCreateFor<PredecessorState>(point, point);
assert(predecessors->allPredecessorsKnown() &&
"unexpected unresolved region successors");
@@ -224,11 +223,11 @@ void AbstractSparseForwardDataFlowAnalysis::visitRegionSuccessors(
// Check if the predecessor is the parent op.
if (op == branch) {
- operands = branch.getEntrySuccessorOperands(successorIndex);
+ operands = branch.getEntrySuccessorOperands(successor);
// Otherwise, try to deduce the operands from a region return-like op.
} else if (auto regionTerminator =
dyn_cast<RegionBranchTerminatorOpInterface>(op)) {
- operands = regionTerminator.getSuccessorOperands(successorIndex);
+ operands = regionTerminator.getSuccessorOperands(successor);
}
if (!operands) {
@@ -501,10 +500,7 @@ void AbstractSparseBackwardDataFlowAnalysis::visitRegionSuccessors(
BitVector unaccounted(op->getNumOperands(), true);
for (RegionSuccessor &successor : successors) {
- Region *region = successor.getSuccessor();
- OperandRange operands =
- region ? branch.getEntrySuccessorOperands(region->getRegionNumber())
- : branch.getEntrySuccessorOperands({});
+ OperandRange operands = branch.getEntrySuccessorOperands(successor);
MutableArrayRef<OpOperand> opoperands = operandsToOpOperands(operands);
ValueRange inputs = successor.getSuccessorInputs();
for (auto [operand, input] : llvm::zip(opoperands, inputs)) {
@@ -538,9 +534,7 @@ void AbstractSparseBackwardDataFlowAnalysis::
for (const RegionSuccessor &successor : successors) {
ValueRange inputs = successor.getSuccessorInputs();
- Region *region = successor.getSuccessor();
- OperandRange operands = terminator.getSuccessorOperands(
- region ? region->getRegionNumber() : std::optional<unsigned>{});
+ OperandRange operands = terminator.getSuccessorOperands(successor);
MutableArrayRef<OpOperand> opOperands = operandsToOpOperands(operands);
for (auto [opOperand, input] : llvm::zip(opOperands, inputs)) {
meet(getLatticeElement(opOperand.get()),
diff --git a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
index bb4aaee21d019e..9d7b8f371a26c6 100644
--- a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
+++ b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
@@ -2379,9 +2379,9 @@ void AffineForOp::getCanonicalizationPatterns(RewritePatternSet &results,
/// correspond to the loop iterator operands, i.e., those excluding the
/// induction variable. AffineForOp only has one region, so zero is the only
/// valid value for `index`.
-OperandRange
-AffineForOp::getEntrySuccessorOperands(std::optional<unsigned> index) {
- assert((!index || *index == 0) && "invalid region index");
+OperandRange AffineForOp::getEntrySuccessorOperands(RegionBranchPoint point) {
+ assert((point.isParent() || point == getLoopBody()) &&
+ "invalid region point");
// The initial operands map to the loop arguments after the induction
// variable or are forwarded to the results when the trip count is zero.
@@ -2394,14 +2394,15 @@ AffineForOp::getEntrySuccessorOperands(std::optional<unsigned> index) {
/// correspond to a constant value for each operand, or null if that operand is
/// not a constant.
void AffineForOp::getSuccessorRegions(
- std::optional<unsigned> index, SmallVectorImpl<RegionSuccessor> &regions) {
- assert((!index.has_value() || index.value() == 0) && "expected loop region");
+ RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
+ assert((point.isParent() || point == getLoopBody()) &&
+ "expected loop region");
// The loop may typically branch back to its body or to the parent operation.
// If the predecessor is the parent op and the trip count is known to be at
// least one, branch into the body using the iterator arguments. And in cases
// we know the trip count is zero, it can only branch back to its parent.
std::optional<uint64_t> tripCount = getTrivialConstantTripCount(*this);
- if (!index.has_value() && tripCount.has_value()) {
+ if (point.isParent() && tripCount.has_value()) {
if (tripCount.value() > 0) {
regions.push_back(RegionSuccessor(&getLoopBody(), getRegionIterArgs()));
return;
@@ -2414,7 +2415,7 @@ void AffineForOp::getSuccessorRegions(
// From the loop body, if the trip count is one, we can only branch back to
// the parent.
- if (index && tripCount && *tripCount == 1) {
+ if (!point.isParent() && tripCount && *tripCount == 1) {
regions.push_back(RegionSuccessor(getResults()));
return;
}
@@ -2859,10 +2860,10 @@ struct AlwaysTrueOrFalseIf : public OpRewritePattern<AffineIfOp> {
/// AffineIfOp has two regions -- `then` and `else`. The flow of data should be
/// as follows: AffineIfOp -> `then`/`else` -> AffineIfOp
void AffineIfOp::getSuccessorRegions(
- std::optional<unsigned> index, SmallVectorImpl<RegionSuccessor> &regions) {
+ RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
// If the predecessor is an AffineIfOp, then branching into both `then` and
// `else` region is valid.
- if (!index.has_value()) {
+ if (point.isParent()) {
regions.reserve(2);
regions.push_back(
RegionSuccessor(&getThenRegion(), getThenRegion().getArguments()));
diff --git a/mlir/lib/Dialect/Async/IR/Async.cpp b/mlir/lib/Dialect/Async/IR/Async.cpp
index 9b4fb81990c169..a05e02faf6d2f0 100644
--- a/mlir/lib/Dialect/Async/IR/Async.cpp
+++ b/mlir/lib/Dialect/Async/IR/Async.cpp
@@ -38,9 +38,8 @@ void AsyncDialect::initialize() {
constexpr char kOperandSegmentSizesAttr[] = "operandSegmentSizes";
-OperandRange
-ExecuteOp::getEntrySuccessorOperands(std::optional<unsigned> index) {
- assert(index && *index == 0 && "invalid region index");
+OperandRange ExecuteOp::getEntrySuccessorOperands(RegionBranchPoint point) {
+ assert(point == getBodyRegion() && "invalid region index");
return getBodyOperands();
}
@@ -53,11 +52,10 @@ bool ExecuteOp::areTypesCompatible(Type lhs, Type rhs) {
return getValueOrTokenType(lhs) == getValueOrTokenType(rhs);
}
-void ExecuteOp::getSuccessorRegions(std::optional<unsigned> index,
+void ExecuteOp::getSuccessorRegions(RegionBranchPoint point,
SmallVectorImpl<RegionSuccessor> &regions) {
// The `body` region branch back to the parent operation.
- if (index) {
- assert(*index == 0 && "invalid region index");
+ if (point == getBodyRegion()) {
regions.push_back(RegionSuccessor(getBodyResults()));
return;
}
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocation.cpp b/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocation.cpp
index 582974873263d2..9a831e4c322ea0 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocation.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocation.cpp
@@ -372,7 +372,7 @@ class BufferDeallocation : public BufferPlacementTransformationBase {
// parent operation. In this case, we have to introduce an additional clone
// for buffer that is passed to the argument.
SmallVector<RegionSuccessor, 2> successorRegions;
- regionInterface.getSuccessorRegions(/*index=*/std::nullopt,
+ regionInterface.getSuccessorRegions(/*point=*/RegionBranchPoint::parent(),
successorRegions);
auto *it =
llvm::find_if(successorRegions, [&](RegionSuccessor &successorRegion) {
@@ -383,8 +383,7 @@ class BufferDeallocation : public BufferPlacementTransformationBase {
// Determine the actual operand to introduce a clone for and rewire the
// operand to point to the clone instead.
- auto operands =
- regionInterface.getEntrySuccessorOperands(argRegion->getRegionNumber());
+ auto operands = regionInterface.getEntrySuccessorOperands(argRegion);
size_t operandIndex =
llvm::find(it->getSuccessorInputs(), blockArg).getIndex() +
operands.getBeginOperandIndex();
@@ -432,8 +431,7 @@ class BufferDeallocation : public BufferPlacementTransformationBase {
// Query the regionInterface to get all successor regions of the current
// one.
SmallVector<RegionSuccessor, 2> successorRegions;
- regionInterface.getSuccessorRegions(region.getRegionNumber(),
- successorRegions);
+ regionInterface.getSuccessorRegions(region, successorRegions);
// Try to find a matching region successor.
RegionSuccessor *regionSuccessor =
llvm::find_if(successorRegions, regionPredicate);
@@ -445,10 +443,6 @@ class BufferDeallocation : public BufferPlacementTransformationBase {
llvm::find(regionSuccessor->getSuccessorInputs(), argValue)
.getIndex();
- std::optional<unsigned> successorRegionNumber;
- if (Region *successorRegion = regionSuccessor->getSuccessor())
- successorRegionNumber = successorRegion->getRegionNumber();
-
// Iterate over all immediate terminator operations to introduce
// new buffer allocations. Thereby, the appropriate terminator operand
// will be adjusted to point to the newly allocated buffer instead.
@@ -456,8 +450,7 @@ class BufferDeallocation : public BufferPlacementTransformationBase {
&region, [&](RegionBranchTerminatorOpInterface terminator) {
// Get the actual mutable operands for this terminator op.
auto terminatorOperands =
- terminator.getMutableSuccessorOperands(
- successorRegionNumber);
+ terminator.getMutableSuccessorOperands(*regionSuccessor);
// Extract the source value from the current terminator.
// This conversion needs to exist on a separate line due to a
// bug in GCC conversion analysis.
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/BufferUtils.cpp b/mlir/lib/Dialect/Bufferization/Transforms/BufferUtils.cpp
index f8231cac778af6..119801f9cc92f3 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/BufferUtils.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/BufferUtils.cpp
@@ -123,7 +123,7 @@ bool BufferPlacementTransformationBase::isLoop(Operation *op) {
return true;
// Recurses into all region successors.
SmallVector<RegionSuccessor, 2> successors;
- regionInterface.getSuccessorRegions(current->getRegionNumber(), successors);
+ regionInterface.getSuccessorRegions(current, successors);
for (RegionSuccessor &regionEntry : successors)
if (recurse(regionEntry.getSuccessor()))
return true;
@@ -132,7 +132,8 @@ bool BufferPlacementTransformationBase::isLoop(Operation *op) {
// Start with all entry regions and test whether they induce a loop.
SmallVector<RegionSuccessor, 2> successorRegions;
- regionInterface.getSuccessorRegions(/*index=*/std::nullopt, successorRegions);
+ regionInterface.getSuccessorRegions(/*point=*/RegionBranchPoint::parent(),
+ successorRegions);
for (RegionSuccessor &regionEntry : successorRegions) {
if (recurse(regionEntry.getSuccessor()))
return true;
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/BufferViewFlowAnalysis.cpp b/mlir/lib/Dialect/Bufferization/Transforms/BufferViewFlowAnalysis.cpp
index d201e024380661..98a60a48763ab1 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/BufferViewFlowAnalysis.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/BufferViewFlowAnalysis.cpp
@@ -100,16 +100,13 @@ void BufferViewFlowAnalysis::build(Operation *op) {
// Query the RegionBranchOpInterface to find potential successor regions.
// Extract all entry regions and wire all initial entry successor inputs.
SmallVector<RegionSuccessor, 2> entrySuccessors;
- regionInterface.getSuccessorRegions(/*index=*/std::nullopt,
+ regionInterface.getSuccessorRegions(/*point=*/RegionBranchPoint::parent(),
entrySuccessors);
for (RegionSuccessor &entrySuccessor : entrySuccessors) {
// Wire the entry region's successor arguments with the initial
// successor inputs.
registerDependencies(
- regionInterface.getEntrySuccessorOperands(
- entrySuccessor.isParent()
- ? std::optional<unsigned>()
- : entrySuccessor.getSuccessor()->getRegionNumber()),
+ regionInterface.getEntrySuccessorOperands(entrySuccessor),
entrySuccessor.getSuccessorInputs());
}
@@ -118,21 +115,16 @@ void BufferViewFlowAnalysis::build(Operation *op) {
// Iterate over all successor region entries that are reachable from the
// current region.
SmallVector<RegionSuccessor, 2> successorRegions;
- regionInterface.getSuccessorRegions(region.getRegionNumber(),
- successorRegions);
+ regionInterface.getSuccessorRegions(region, successorRegions);
for (RegionSuccessor &successorRegion : successorRegions) {
- // Determine the current region index (if any).
- std::optional<unsigned> regionIndex;
- Region *regionSuccessor = successorRegion.getSuccessor();
- if (regionSuccessor)
- regionIndex = regionSuccessor->getRegionNumber();
// Iterate over all immediate terminator operations and wire the
// successor inputs with the successor operands of each terminator.
for (Block &block : region)
if (auto terminator = dyn_cast<RegionBranchTerminatorOpInterface>(
block.getTerminator()))
- registerDependencies(terminator.getSuccessorOperands(regionIndex),
- successorRegion.getSuccessorInputs());
+ registerDependencies(
+ terminator.getSuccessorOperands(successorRegion),
+ successorRegion.getSuccessorInputs());
}
}
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index e1b8dd62450a77..9c5c322e23692b 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -455,8 +455,8 @@ ParseResult AllocaScopeOp::parse(OpAsmParser &parser, OperationState &result) {
}
void AllocaScopeOp::getSuccessorRegions(
- std::optional<unsigned> index, SmallVectorImpl<RegionSuccessor> &regions) {
- if (index) {
+ RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
+ if (!point.isParent()) {
regions.push_back(RegionSuccessor(getResults()));
return;
}
diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
index 63ce3b2a469627..b573291f0460e6 100644
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -266,9 +266,9 @@ void ExecuteRegionOp::getCanonicalizationPatterns(RewritePatternSet &results,
/// correspond to a constant value for each operand, or null if that operand is
/// not a constant.
void ExecuteRegionOp::getSuccessorRegions(
- std::optional<unsigned> index, SmallVectorImpl<RegionSuccessor> &regions) {
+ RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
// If the predecessor is the ExecuteRegionOp, branch into the body.
- if (!index) {
+ if (point.isParent()) {
regions.push_back(RegionSuccessor(&getRegion()));
return;
}
@@ -282,8 +282,8 @@ void ExecuteRegionOp::getSuccessorRegions(
//===----------------------------------------------------------------------===//
MutableOperandRange
-ConditionOp::getMutableSuccessorOperands(std::optional<unsigned> index) {
- assert((!index || index == getParentOp().getAfter().getRegionNumber()) &&
+ConditionOp::getMutableSuccessorOperands(RegionBranchPoint point) {
+ assert((point.isParent() || point == getParentOp().getAfter()) &&
"condition op can only exit the loop or branch to the after"
"region");
// Pass all operands except the condition to the successor region.
@@ -553,7 +553,7 @@ ForOp mlir::scf::getForInductionVarOwner(Value val) {
/// Return operands used when entering the region at 'index'. These operands
/// correspond to the loop iterator operands, i.e., those excluding the
/// induction variable.
-OperandRange ForOp::getEntrySuccessorOperands(std::optional<unsigned> index) {
+OperandRange ForOp::getEntrySuccessorOperands(RegionBranchPoint point) {
return getInitArgs();
}
@@ -562,7 +562,7 @@ OperandRange ForOp::getEntrySuccessorOperands(std::optional<unsigned> index) {
/// during the flow of control. `operands` is a set of optional attributes that
/// correspond to a constant value for each operand, or null if that operand is
/// not a constant.
-void ForOp::getSuccessorRegions(std::optional<unsigned> index,
+void ForOp::getSuccessorRegions(RegionBranchPoint point,
SmallVectorImpl<RegionSuccessor> &regions) {
// Both the operation itself and the region may be branching into the body or
// back into the operation itself. It is possible for loop not to enter the
@@ -1731,7 +1731,7 @@ void ForallOp::getCanonicalizationPatterns(RewritePatternSet &results,
/// during the flow of control. `operands` is a set of optional attributes that
/// correspond to a constant value for each operand, or null if that operand is
/// not a constant.
-void ForallOp::getSuccessorRegions(std::optional<unsigned> index,
+void ForallOp::getSuccessorRegions(RegionBranchPoint point,
SmallVectorImpl<RegionSuccessor> &regions) {
// Both the operation itself and the region may be branching into the body or
// back into the operation itself. It is possible for loop not to enter the
@@ -2011,10 +2011,10 @@ void IfOp::print(OpAsmPrinter &p) {
/// during the flow of control. `operands` is a set of optional attributes that
/// correspond to a constant value for each operand, or null if that operand is
/// not a constant.
-void IfOp::getSuccessorRegions(std::optional<unsigned> index,
+void IfOp::getSuccessorRegions(RegionBranchPoint point,
SmallVectorImpl<RegionSuccessor> &regions) {
// The `then` and the `else` region branch back to the parent operation.
- if (index) {
+ if (!point.isParent()) {
regions.push_back(RegionSuccessor(getResults()));
return;
}
@@ -3042,7 +3042,7 @@ void ParallelOp::getCanonicalizationPatterns(RewritePatternSet &results,
/// correspond to a constant value for each operand, or null if that operand is
/// not a constant.
void ParallelOp::getSuccessorRegions(
- std::optional<unsigned> index, SmallVectorImpl<RegionSuccessor> &regions) {
+ RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
// Both the operation itself and the region may be branching into the body or
// back into the operation itself. It is possible for loop not to enter the
// body.
@@ -3169,8 +3169,8 @@ void WhileOp::build(::mlir::OpBuilder &odsBuilder,
afterBuilder(odsBuilder, odsState.location, afterBlock->getArguments());
}
-OperandRange WhileOp::getEntrySuccessorOperands(std::optional<unsigned> index) {
- assert(index && *index == 0 &&
+OperandRange WhileOp::getEntrySuccessorOperands(RegionBranchPoint point) {
+ assert(point == getBefore() &&
"WhileOp is expected to branch only to the first region");
return getInits();
@@ -3192,17 +3192,18 @@ Block::BlockArgListType WhileOp::getAfterArguments() {
return getAfterBody()->getArguments();
}
-void WhileOp::getSuccessorRegions(std::optional<unsigned> index,
+void WhileOp::getSuccessorRegions(RegionBranchPoint point,
SmallVectorImpl<RegionSuccessor> &regions) {
// The parent op always branches to the condition region.
- if (!index) {
+ if (point.isParent()) {
regions.emplace_back(&getBefore(), getBefore().getArguments());
return;
}
- assert(*index < 2 && "there are only two regions in a WhileOp");
+ assert(llvm::is_contained({&getAfter(), &getBefore()}, point) &&
+ "there are only two regions in a WhileOp");
// The body region always branches back to the condition region.
- if (*index == 1) {
+ if (point == getAfter()) {
regions.emplace_back(&getBefore(), getBefore().getArguments());
return;
}
@@ -4023,10 +4024,9 @@ Block &scf::IndexSwitchOp::getCaseBlock(unsigned idx) {
}
void IndexSwitchOp::getSuccessorRegions(
- std::optional<unsigned> index,
- SmallVectorImpl<RegionSuccessor> &successors) {
+ RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &successors) {
// All regions branch back to the parent op.
- if (index) {
+ if (!point.isParent()) {
successors.emplace_back(getResults());
return;
}
diff --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp
index c52a9c5004aaaf..78b06d9ce033f8 100644
--- a/mlir/lib/Dialect/Shape/IR/Shape.cpp
+++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp
@@ -335,11 +335,11 @@ void AssumingOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
// See RegionBranchOpInterface in Interfaces/ControlFlowInterfaces.td
void AssumingOp::getSuccessorRegions(
- std::optional<unsigned> index, SmallVectorImpl<RegionSuccessor> &regions) {
+ RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
// AssumingOp has unconditional control flow into the region and back to the
// parent, so return the correct RegionSuccessor purely based on the index
// being None or 0.
- if (index) {
+ if (!point.isParent()) {
regions.push_back(RegionSuccessor(getResults()));
return;
}
diff --git a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
index 7bc7272b054129..518bfc3931e8d3 100644
--- a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
@@ -86,23 +86,25 @@ ensurePayloadIsSeparateFromTransform(transform::TransformOpInterface transform,
// AlternativesOp
//===----------------------------------------------------------------------===//
-OperandRange transform::AlternativesOp::getEntrySuccessorOperands(
- std::optional<unsigned> index) {
- if (index && getOperation()->getNumOperands() == 1)
+OperandRange
+transform::AlternativesOp::getEntrySuccessorOperands(RegionBranchPoint point) {
+ if (!point.isParent() && getOperation()->getNumOperands() == 1)
return getOperation()->getOperands();
return OperandRange(getOperation()->operand_end(),
getOperation()->operand_end());
}
void transform::AlternativesOp::getSuccessorRegions(
- std::optional<unsigned> index, SmallVectorImpl<RegionSuccessor> &regions) {
+ RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
for (Region &alternative : llvm::drop_begin(
- getAlternatives(), index.has_value() ? *index + 1 : 0)) {
+ getAlternatives(),
+ point.isParent() ? 0
+ : point.getRegionOrNull()->getRegionNumber() + 1)) {
regions.emplace_back(&alternative, !getOperands().empty()
? alternative.getArguments()
: Block::BlockArgListType());
}
- if (index.has_value())
+ if (!point.isParent())
regions.emplace_back(getOperation()->getResults());
}
@@ -1159,24 +1161,24 @@ void transform::ForeachOp::getEffects(
}
void transform::ForeachOp::getSuccessorRegions(
- std::optional<unsigned> index, SmallVectorImpl<RegionSuccessor> &regions) {
+ RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
Region *bodyRegion = &getBody();
- if (!index) {
+ if (point.isParent()) {
regions.emplace_back(bodyRegion, bodyRegion->getArguments());
return;
}
// Branch back to the region or the parent.
- assert(*index == 0 && "unexpected region index");
+ assert(point == getBody() && "unexpected region index");
regions.emplace_back(bodyRegion, bodyRegion->getArguments());
regions.emplace_back();
}
OperandRange
-transform::ForeachOp::getEntrySuccessorOperands(std::optional<unsigned> index) {
+transform::ForeachOp::getEntrySuccessorOperands(RegionBranchPoint point) {
// The iteration variable op handle is mapped to a subset (one op to be
// precise) of the payload ops of the ForeachOp operand.
- assert(index && *index == 0 && "unexpected region index");
+ assert(point == getBody() && "unexpected region index");
return getOperation()->getOperands();
}
@@ -2178,9 +2180,9 @@ void transform::SequenceOp::getEffects(
getPotentialTopLevelEffects(effects);
}
-OperandRange transform::SequenceOp::getEntrySuccessorOperands(
- std::optional<unsigned> index) {
- assert(index && *index == 0 && "unexpected region index");
+OperandRange
+transform::SequenceOp::getEntrySuccessorOperands(RegionBranchPoint point) {
+ assert(point == getBody() && "unexpected region index");
if (getOperation()->getNumOperands() > 0)
return getOperation()->getOperands();
return OperandRange(getOperation()->operand_end(),
@@ -2188,8 +2190,8 @@ OperandRange transform::SequenceOp::getEntrySuccessorOperands(
}
void transform::SequenceOp::getSuccessorRegions(
- std::optional<unsigned> index, SmallVectorImpl<RegionSuccessor> &regions) {
- if (!index) {
+ RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
+ if (point.isParent()) {
Region *bodyRegion = &getBody();
regions.emplace_back(bodyRegion, getNumOperands() != 0
? bodyRegion->getArguments()
@@ -2197,7 +2199,7 @@ void transform::SequenceOp::getSuccessorRegions(
return;
}
- assert(*index == 0 && "unexpected region index");
+ assert(point == getBody() && "unexpected region index");
regions.emplace_back(getOperation()->getResults());
}
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 4e9364611b257d..88bda3931a5a11 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -5821,8 +5821,8 @@ ParseResult WarpExecuteOnLane0Op::parse(OpAsmParser &parser,
}
void WarpExecuteOnLane0Op::getSuccessorRegions(
- std::optional<unsigned> index, SmallVectorImpl<RegionSuccessor> &regions) {
- if (index) {
+ RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
+ if (!point.isParent()) {
regions.push_back(RegionSuccessor(getResults()));
return;
}
diff --git a/mlir/lib/Interfaces/ControlFlowInterfaces.cpp b/mlir/lib/Interfaces/ControlFlowInterfaces.cpp
index cc90da370de693..b3166155e84f93 100644
--- a/mlir/lib/Interfaces/ControlFlowInterfaces.cpp
+++ b/mlir/lib/Interfaces/ControlFlowInterfaces.cpp
@@ -84,18 +84,18 @@ detail::verifyBranchSuccessorOperands(Operation *op, unsigned succNo,
// RegionBranchOpInterface
//===----------------------------------------------------------------------===//
-static InFlightDiagnostic &
-printRegionEdgeName(InFlightDiagnostic &diag, std::optional<unsigned> sourceNo,
- std::optional<unsigned> succRegionNo) {
+static InFlightDiagnostic &printRegionEdgeName(InFlightDiagnostic &diag,
+ RegionBranchPoint sourceNo,
+ RegionBranchPoint succRegionNo) {
diag << "from ";
- if (sourceNo)
- diag << "Region #" << sourceNo.value();
+ if (Region *region = sourceNo.getRegionOrNull())
+ diag << "Region #" << region->getRegionNumber();
else
diag << "parent operands";
diag << " to ";
- if (succRegionNo)
- diag << "Region #" << succRegionNo.value();
+ if (Region *region = succRegionNo.getRegionOrNull())
+ diag << "Region #" << region->getRegionNumber();
else
diag << "parent results";
return diag;
@@ -107,28 +107,24 @@ printRegionEdgeName(InFlightDiagnostic &diag, std::optional<unsigned> sourceNo,
/// inputs that flow from `sourceIndex' to the given region, or std::nullopt if
/// the exact type match verification is not necessary (e.g., if the Op verifies
/// the match itself).
-static LogicalResult verifyTypesAlongAllEdges(
- Operation *op, std::optional<unsigned> sourceNo,
- function_ref<FailureOr<TypeRange>(std::optional<unsigned>)>
- getInputsTypesForRegion) {
+static LogicalResult
+verifyTypesAlongAllEdges(Operation *op, RegionBranchPoint sourcePoint,
+ function_ref<FailureOr<TypeRange>(RegionBranchPoint)>
+ getInputsTypesForRegion) {
auto regionInterface = cast<RegionBranchOpInterface>(op);
SmallVector<RegionSuccessor, 2> successors;
- regionInterface.getSuccessorRegions(sourceNo, successors);
+ regionInterface.getSuccessorRegions(sourcePoint, successors);
for (RegionSuccessor &succ : successors) {
- std::optional<unsigned> succRegionNo;
- if (!succ.isParent())
- succRegionNo = succ.getSuccessor()->getRegionNumber();
-
- FailureOr<TypeRange> sourceTypes = getInputsTypesForRegion(succRegionNo);
+ FailureOr<TypeRange> sourceTypes = getInputsTypesForRegion(succ);
if (failed(sourceTypes))
return failure();
TypeRange succInputsTypes = succ.getSuccessorInputs().getTypes();
if (sourceTypes->size() != succInputsTypes.size()) {
InFlightDiagnostic diag = op->emitOpError(" region control flow edge ");
- return printRegionEdgeName(diag, sourceNo, succRegionNo)
+ return printRegionEdgeName(diag, sourcePoint, succ)
<< ": source has " << sourceTypes->size()
<< " operands, but target successor needs "
<< succInputsTypes.size();
@@ -140,7 +136,7 @@ static LogicalResult verifyTypesAlongAllEdges(
Type inputType = std::get<1>(typesIdx.value());
if (!regionInterface.areTypesCompatible(sourceType, inputType)) {
InFlightDiagnostic diag = op->emitOpError(" along control flow edge ");
- return printRegionEdgeName(diag, sourceNo, succRegionNo)
+ return printRegionEdgeName(diag, sourcePoint, succ)
<< ": source type #" << typesIdx.index() << " " << sourceType
<< " should match input type #" << typesIdx.index() << " "
<< inputType;
@@ -154,13 +150,13 @@ static LogicalResult verifyTypesAlongAllEdges(
LogicalResult detail::verifyTypesAlongControlFlowEdges(Operation *op) {
auto regionInterface = cast<RegionBranchOpInterface>(op);
- auto inputTypesFromParent =
- [&](std::optional<unsigned> regionNo) -> TypeRange {
+ auto inputTypesFromParent = [&](RegionBranchPoint regionNo) -> TypeRange {
return regionInterface.getEntrySuccessorOperands(regionNo).getTypes();
};
// Verify types along control flow edges originating from the parent.
- if (failed(verifyTypesAlongAllEdges(op, std::nullopt, inputTypesFromParent)))
+ if (failed(verifyTypesAlongAllEdges(op, RegionBranchPoint::parent(),
+ inputTypesFromParent)))
return failure();
auto areTypesCompatible = [&](TypeRange lhs, TypeRange rhs) {
@@ -176,8 +172,7 @@ LogicalResult detail::verifyTypesAlongControlFlowEdges(Operation *op) {
};
// Verify types along control flow edges originating from each region.
- for (unsigned regionNo : llvm::seq(0U, op->getNumRegions())) {
- Region &region = op->getRegion(regionNo);
+ for (Region &region : op->getRegions()) {
// Since there can be multiple terminators implementing the
// `RegionBranchTerminatorOpInterface`, all should have the same operand
@@ -195,7 +190,7 @@ LogicalResult detail::verifyTypesAlongControlFlowEdges(Operation *op) {
continue;
auto inputTypesForRegion =
- [&](std::optional<unsigned> succRegionNo) -> FailureOr<TypeRange> {
+ [&](RegionBranchPoint succRegionNo) -> FailureOr<TypeRange> {
std::optional<OperandRange> regionReturnOperands;
for (RegionBranchTerminatorOpInterface regionReturnOp : regionReturnOps) {
auto terminatorOperands =
@@ -211,7 +206,7 @@ LogicalResult detail::verifyTypesAlongControlFlowEdges(Operation *op) {
if (!areTypesCompatible(regionReturnOperands->getTypes(),
terminatorOperands.getTypes())) {
InFlightDiagnostic diag = op->emitOpError(" along control flow edge");
- return printRegionEdgeName(diag, regionNo, succRegionNo)
+ return printRegionEdgeName(diag, region, succRegionNo)
<< " operands mismatch between return-like terminators";
}
}
@@ -220,7 +215,7 @@ LogicalResult detail::verifyTypesAlongControlFlowEdges(Operation *op) {
return TypeRange(regionReturnOperands->getTypes());
};
- if (failed(verifyTypesAlongAllEdges(op, regionNo, inputTypesForRegion)))
+ if (failed(verifyTypesAlongAllEdges(op, region, inputTypesForRegion)))
return failure();
}
@@ -237,24 +232,24 @@ static bool isRegionReachable(Region *begin, Region *r) {
visited[begin->getRegionNumber()] = true;
// Retrieve all successors of the region and enqueue them in the worklist.
- SmallVector<unsigned> worklist;
- auto enqueueAllSuccessors = [&](unsigned index) {
+ SmallVector<Region *> worklist;
+ auto enqueueAllSuccessors = [&](Region *region) {
SmallVector<RegionSuccessor> successors;
- op.getSuccessorRegions(index, successors);
+ op.getSuccessorRegions(region, successors);
for (RegionSuccessor successor : successors)
if (!successor.isParent())
- worklist.push_back(successor.getSuccessor()->getRegionNumber());
+ worklist.push_back(successor.getSuccessor());
};
- enqueueAllSuccessors(begin->getRegionNumber());
+ enqueueAllSuccessors(begin);
// Process all regions in the worklist via DFS.
while (!worklist.empty()) {
- unsigned nextRegion = worklist.pop_back_val();
- if (nextRegion == r->getRegionNumber())
+ Region *nextRegion = worklist.pop_back_val();
+ if (nextRegion == r)
return true;
- if (visited[nextRegion])
+ if (visited[nextRegion->getRegionNumber()])
continue;
- visited[nextRegion] = true;
+ visited[nextRegion->getRegionNumber()] = true;
enqueueAllSuccessors(nextRegion);
}
diff --git a/mlir/lib/Transforms/RemoveDeadValues.cpp b/mlir/lib/Transforms/RemoveDeadValues.cpp
index ce19dc667f009d..19a84db34dce02 100644
--- a/mlir/lib/Transforms/RemoveDeadValues.cpp
+++ b/mlir/lib/Transforms/RemoveDeadValues.cpp
@@ -316,15 +316,11 @@ static void cleanRegionBranchOp(RegionBranchOpInterface regionBranchOp,
// Return the successors of `region` if the latter is not null. Else return
// the successors of `regionBranchOp`.
auto getSuccessors = [&](Region *region = nullptr) {
- std::optional<unsigned> index =
- region ? std::optional(region->getRegionNumber()) : std::nullopt;
+ auto point = region ? region : RegionBranchPoint::parent();
SmallVector<Attribute> operandAttributes(regionBranchOp->getNumOperands(),
nullptr);
SmallVector<RegionSuccessor> successors;
- if (!index)
- regionBranchOp.getEntrySuccessorRegions(operandAttributes, successors);
- else
- regionBranchOp.getSuccessorRegions(index, successors);
+ regionBranchOp.getSuccessorRegions(point, successors);
return successors;
};
@@ -333,14 +329,10 @@ static void cleanRegionBranchOp(RegionBranchOpInterface regionBranchOp,
// forwarded to `successor`.
auto getForwardedOpOperands = [&](const RegionSuccessor &successor,
Operation *terminator = nullptr) {
- Region *successorRegion = successor.getSuccessor();
- std::optional<unsigned> index =
- successorRegion ? std::optional(successorRegion->getRegionNumber())
- : std::nullopt;
OperandRange operands =
terminator ? cast<RegionBranchTerminatorOpInterface>(terminator)
- .getSuccessorOperands(index)
- : regionBranchOp.getEntrySuccessorOperands(index);
+ .getSuccessorOperands(successor)
+ : regionBranchOp.getEntrySuccessorOperands(successor);
SmallVector<OpOperand *> opOperands = operandsToOpOperands(operands);
return opOperands;
};
diff --git a/mlir/test/lib/Analysis/DataFlow/TestDenseBackwardDataFlowAnalysis.cpp b/mlir/test/lib/Analysis/DataFlow/TestDenseBackwardDataFlowAnalysis.cpp
index a33b523d5d192f..8bfd01d828060a 100644
--- a/mlir/test/lib/Analysis/DataFlow/TestDenseBackwardDataFlowAnalysis.cpp
+++ b/mlir/test/lib/Analysis/DataFlow/TestDenseBackwardDataFlowAnalysis.cpp
@@ -60,8 +60,8 @@ class NextAccessAnalysis : public DenseBackwardDataFlowAnalysis<NextAccess> {
NextAccess *before) override;
void visitRegionBranchControlFlowTransfer(RegionBranchOpInterface branch,
- std::optional<unsigned> regionFrom,
- std::optional<unsigned> regionTo,
+ RegionBranchPoint regionFrom,
+ RegionBranchPoint regionTo,
const NextAccess &after,
NextAccess *before) override;
@@ -124,15 +124,15 @@ void NextAccessAnalysis::visitCallControlFlowTransfer(
}
void NextAccessAnalysis::visitRegionBranchControlFlowTransfer(
- RegionBranchOpInterface branch, std::optional<unsigned> regionFrom,
- std::optional<unsigned> regionTo, const NextAccess &after,
- NextAccess *before) {
+ RegionBranchOpInterface branch, RegionBranchPoint regionFrom,
+ RegionBranchPoint regionTo, const NextAccess &after, NextAccess *before) {
auto testStoreWithARegion =
dyn_cast<::test::TestStoreWithARegion>(branch.getOperation());
if (testStoreWithARegion &&
- ((!regionTo && !testStoreWithARegion.getStoreBeforeRegion()) ||
- (!regionFrom && testStoreWithARegion.getStoreBeforeRegion()))) {
+ ((regionTo.isParent() && !testStoreWithARegion.getStoreBeforeRegion()) ||
+ (regionFrom.isParent() &&
+ testStoreWithARegion.getStoreBeforeRegion()))) {
visitOperation(branch, static_cast<const NextAccess &>(after),
static_cast<NextAccess *>(before));
} else {
@@ -219,7 +219,7 @@ struct TestNextAccessPass
SmallVector<Attribute> entryPointNextAccess;
SmallVector<RegionSuccessor> regionSuccessors;
- iface.getSuccessorRegions(std::nullopt, regionSuccessors);
+ iface.getSuccessorRegions(RegionBranchPoint::parent(), regionSuccessors);
for (const RegionSuccessor &successor : regionSuccessors) {
if (!successor.getSuccessor() || successor.getSuccessor()->empty())
continue;
diff --git a/mlir/test/lib/Dialect/Test/TestDialect.cpp b/mlir/test/lib/Dialect/Test/TestDialect.cpp
index 57a6ab387281dc..34ed7a1a66fe33 100644
--- a/mlir/test/lib/Dialect/Test/TestDialect.cpp
+++ b/mlir/test/lib/Dialect/Test/TestDialect.cpp
@@ -931,17 +931,17 @@ ParseResult RegionIfOp::parse(OpAsmParser &parser, OperationState &result) {
parser.getCurrentLocation(), result.operands);
}
-OperandRange
-RegionIfOp::getEntrySuccessorOperands(std::optional<unsigned> index) {
- assert(index && *index < 2 && "invalid region index");
+OperandRange RegionIfOp::getEntrySuccessorOperands(RegionBranchPoint point) {
+ assert(llvm::is_contained({&getThenRegion(), &getElseRegion()}, point) &&
+ "invalid region index");
return getOperands();
}
void RegionIfOp::getSuccessorRegions(
- std::optional<unsigned> index, SmallVectorImpl<RegionSuccessor> &regions) {
+ RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
// We always branch to the join region.
- if (index.has_value()) {
- if (index.value() < 2)
+ if (!point.isParent()) {
+ if (point != getJoinRegion())
regions.push_back(RegionSuccessor(&getJoinRegion(), getJoinArgs()));
else
regions.push_back(RegionSuccessor(getResults()));
@@ -964,11 +964,11 @@ void RegionIfOp::getRegionInvocationBounds(
// AnyCondOp
//===----------------------------------------------------------------------===//
-void AnyCondOp::getSuccessorRegions(std::optional<unsigned> index,
+void AnyCondOp::getSuccessorRegions(RegionBranchPoint point,
SmallVectorImpl<RegionSuccessor> &regions) {
// The parent op branches into the only region, and the region branches back
// to the parent op.
- if (!index)
+ if (point.isParent())
regions.emplace_back(&getRegion());
else
regions.emplace_back(getResults());
@@ -985,17 +985,16 @@ void AnyCondOp::getRegionInvocationBounds(
//===----------------------------------------------------------------------===//
void LoopBlockOp::getSuccessorRegions(
- std::optional<unsigned> index, SmallVectorImpl<RegionSuccessor> &regions) {
+ RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
regions.emplace_back(&getBody(), getBody().getArguments());
- if (!index)
+ if (point.isParent())
return;
regions.emplace_back((*this)->getResults());
}
-OperandRange
-LoopBlockOp::getEntrySuccessorOperands(std::optional<unsigned> index) {
- assert(index == 0);
+OperandRange LoopBlockOp::getEntrySuccessorOperands(RegionBranchPoint point) {
+ assert(point == getBody());
return getInitMutable();
}
@@ -1003,10 +1002,9 @@ LoopBlockOp::getEntrySuccessorOperands(std::optional<unsigned> index) {
// LoopBlockTerminatorOp
//===----------------------------------------------------------------------===//
-MutableOperandRange LoopBlockTerminatorOp::getMutableSuccessorOperands(
- std::optional<unsigned> index) {
- assert(!index || index == 0);
- if (!index)
+MutableOperandRange
+LoopBlockTerminatorOp::getMutableSuccessorOperands(RegionBranchPoint point) {
+ if (point.isParent())
return getExitArgMutable();
return getNextIterArgMutable();
}
@@ -1313,12 +1311,11 @@ MutableOperandRange TestCallOnDeviceOp::getArgOperandsMutable() {
}
void TestStoreWithARegion::getSuccessorRegions(
- std::optional<unsigned> index, SmallVectorImpl<RegionSuccessor> &regions) {
- if (!index) {
+ RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
+ if (point.isParent())
regions.emplace_back(&getBody(), getBody().front().getArguments());
- } else {
+ else
regions.emplace_back();
- }
}
LogicalResult
diff --git a/mlir/unittests/Interfaces/ControlFlowInterfacesTest.cpp b/mlir/unittests/Interfaces/ControlFlowInterfacesTest.cpp
index a507baa6445d97..f1aae15393fd3f 100644
--- a/mlir/unittests/Interfaces/ControlFlowInterfacesTest.cpp
+++ b/mlir/unittests/Interfaces/ControlFlowInterfacesTest.cpp
@@ -37,7 +37,7 @@ struct MutuallyExclusiveRegionsOp
}
// Regions have no successors.
- void getSuccessorRegions(std::optional<unsigned> index,
+ void getSuccessorRegions(RegionBranchPoint point,
SmallVectorImpl<RegionSuccessor> &regions) {}
};
@@ -51,14 +51,13 @@ struct LoopRegionsOp
static StringRef getOperationName() { return "cftest.loop_regions_op"; }
- void getSuccessorRegions(std::optional<unsigned> index,
+ void getSuccessorRegions(RegionBranchPoint point,
SmallVectorImpl<RegionSuccessor> &regions) {
- if (index) {
- if (*index == 1)
+ if (Region *region = point.getRegionOrNull()) {
+ if (point == (*this)->getRegion(1))
// This region also branches back to the parent.
regions.push_back(RegionSuccessor());
- regions.push_back(
- RegionSuccessor(&getOperation()->getRegion(*index % kNumRegions)));
+ regions.push_back(RegionSuccessor(region));
}
}
};
@@ -74,11 +73,11 @@ struct DoubleLoopRegionsOp
return "cftest.double_loop_regions_op";
}
- void getSuccessorRegions(std::optional<unsigned> index,
+ void getSuccessorRegions(RegionBranchPoint point,
SmallVectorImpl<RegionSuccessor> &regions) {
- if (index.has_value()) {
+ if (Region *region = point.getRegionOrNull()) {
regions.push_back(RegionSuccessor());
- regions.push_back(RegionSuccessor(&getOperation()->getRegion(*index)));
+ regions.push_back(RegionSuccessor(region));
}
}
};
@@ -92,9 +91,9 @@ struct SequentialRegionsOp
static StringRef getOperationName() { return "cftest.sequential_regions_op"; }
// Region 0 has Region 1 as a successor.
- void getSuccessorRegions(std::optional<unsigned> index,
+ void getSuccessorRegions(RegionBranchPoint point,
SmallVectorImpl<RegionSuccessor> &regions) {
- if (index == 0u) {
+ if (point == (*this)->getRegion(0)) {
Operation *thisOp = this->getOperation();
regions.push_back(RegionSuccessor(&thisOp->getRegion(1)));
}
More information about the Mlir-commits
mailing list