[Mlir-commits] [mlir] 4dd744a - Reland "[mlir] Use a type for representing branch points in `RegionBranchOpInterface`"
Markus Böck
llvmlistbot at llvm.org
Wed Aug 30 00:32:01 PDT 2023
Author: Markus Böck
Date: 2023-08-30T09:31:54+02:00
New Revision: 4dd744ac9c0f772a61dd91c84bc14d17e69aec51
URL: https://github.com/llvm/llvm-project/commit/4dd744ac9c0f772a61dd91c84bc14d17e69aec51
DIFF: https://github.com/llvm/llvm-project/commit/4dd744ac9c0f772a61dd91c84bc14d17e69aec51.diff
LOG: Reland "[mlir] Use a type for representing branch points in `RegionBranchOpInterface`"
This reverts commit b26bb30b467b996c9786e3bd426c07684d84d406.
Added:
Modified:
flang/lib/Optimizer/Dialect/FIROps.cpp
mlir/include/mlir/Analysis/DataFlow/DenseAnalysis.h
mlir/include/mlir/Analysis/DataFlow/SparseAnalysis.h
mlir/include/mlir/Interfaces/ControlFlowInterfaces.h
mlir/include/mlir/Interfaces/ControlFlowInterfaces.td
mlir/lib/Analysis/AliasAnalysis/LocalAliasAnalysis.cpp
mlir/lib/Analysis/DataFlow/DenseAnalysis.cpp
mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp
mlir/lib/Dialect/Affine/IR/AffineOps.cpp
mlir/lib/Dialect/Async/IR/Async.cpp
mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocation.cpp
mlir/lib/Dialect/Bufferization/Transforms/BufferUtils.cpp
mlir/lib/Dialect/Bufferization/Transforms/BufferViewFlowAnalysis.cpp
mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
mlir/lib/Dialect/SCF/IR/SCF.cpp
mlir/lib/Dialect/Shape/IR/Shape.cpp
mlir/lib/Dialect/Transform/IR/TransformOps.cpp
mlir/lib/Dialect/Vector/IR/VectorOps.cpp
mlir/lib/Interfaces/ControlFlowInterfaces.cpp
mlir/lib/Transforms/RemoveDeadValues.cpp
mlir/test/lib/Analysis/DataFlow/TestDenseBackwardDataFlowAnalysis.cpp
mlir/test/lib/Dialect/Test/TestDialect.cpp
mlir/unittests/Interfaces/ControlFlowInterfacesTest.cpp
Removed:
################################################################################
diff --git a/flang/lib/Optimizer/Dialect/FIROps.cpp b/flang/lib/Optimizer/Dialect/FIROps.cpp
index bbe06577c27e7b..80567b19f9fe5e 100644
--- a/flang/lib/Optimizer/Dialect/FIROps.cpp
+++ b/flang/lib/Optimizer/Dialect/FIROps.cpp
@@ -3467,10 +3467,10 @@ void fir::IfOp::build(mlir::OpBuilder &builder, mlir::OperationState &result,
/// return the successor regions. These are the regions that may be selected
/// during the flow of control.
void fir::IfOp::getSuccessorRegions(
- std::optional<unsigned> index,
+ mlir::RegionBranchPoint point,
llvm::SmallVectorImpl<mlir::RegionSuccessor> &regions) {
// The `then` and the `else` region branch back to the parent operation.
- if (index) {
+ if (!point.isParent()) {
regions.push_back(mlir::RegionSuccessor(getResults()));
return;
}
diff --git a/mlir/include/mlir/Analysis/DataFlow/DenseAnalysis.h b/mlir/include/mlir/Analysis/DataFlow/DenseAnalysis.h
index a3a558f7705074..6a1335bab8bf6e 100644
--- a/mlir/include/mlir/Analysis/DataFlow/DenseAnalysis.h
+++ b/mlir/include/mlir/Analysis/DataFlow/DenseAnalysis.h
@@ -353,8 +353,8 @@ class AbstractDenseBackwardDataFlowAnalysis : public DataFlowAnalysis {
/// any effect on the lattice that isn't already expressed by the interface
/// itself.
virtual void visitRegionBranchControlFlowTransfer(
- RegionBranchOpInterface branch, std::optional<unsigned> regionFrom,
- std::optional<unsigned> regionTo, const AbstractDenseLattice &after,
+ RegionBranchOpInterface branch, RegionBranchPoint regionFrom,
+ RegionBranchPoint regionTo, const AbstractDenseLattice &after,
AbstractDenseLattice *before) {
meet(before, after);
}
@@ -382,7 +382,7 @@ class AbstractDenseBackwardDataFlowAnalysis : public DataFlowAnalysis {
/// of the branch operation itself.
void visitRegionBranchOperation(ProgramPoint point,
RegionBranchOpInterface branch,
- std::optional<unsigned> regionNo,
+ RegionBranchPoint branchPoint,
AbstractDenseLattice *before);
/// Visit an operation for which the data flow is described by the
@@ -472,9 +472,8 @@ class DenseBackwardDataFlowAnalysis
/// nullptr`. The behavior can be further refined for specific pairs of "from"
/// and "to" regions.
virtual void visitRegionBranchControlFlowTransfer(
- RegionBranchOpInterface branch, std::optional<unsigned> regionFrom,
- std::optional<unsigned> regionTo, const LatticeT &after,
- LatticeT *before) {
+ RegionBranchOpInterface branch, RegionBranchPoint regionFrom,
+ RegionBranchPoint regionTo, const LatticeT &after, LatticeT *before) {
AbstractDenseBackwardDataFlowAnalysis::visitRegionBranchControlFlowTransfer(
branch, regionFrom, regionTo, after, before);
}
@@ -508,8 +507,8 @@ class DenseBackwardDataFlowAnalysis
static_cast<LatticeT *>(before));
}
void visitRegionBranchControlFlowTransfer(
- RegionBranchOpInterface branch, std::optional<unsigned> regionForm,
- std::optional<unsigned> regionTo, const AbstractDenseLattice &after,
+ RegionBranchOpInterface branch, RegionBranchPoint regionForm,
+ RegionBranchPoint regionTo, const AbstractDenseLattice &after,
AbstractDenseLattice *before) final {
visitRegionBranchControlFlowTransfer(branch, regionForm, regionTo,
static_cast<const LatticeT &>(after),
diff --git a/mlir/include/mlir/Analysis/DataFlow/SparseAnalysis.h b/mlir/include/mlir/Analysis/DataFlow/SparseAnalysis.h
index 13dacff3aa0422..5a9a36159b56c5 100644
--- a/mlir/include/mlir/Analysis/DataFlow/SparseAnalysis.h
+++ b/mlir/include/mlir/Analysis/DataFlow/SparseAnalysis.h
@@ -243,7 +243,7 @@ class AbstractSparseForwardDataFlowAnalysis : public DataFlowAnalysis {
/// regions or the parent operation itself, and set either the argument or
/// parent result lattices.
void visitRegionSuccessors(ProgramPoint point, RegionBranchOpInterface branch,
- std::optional<unsigned> successorIndex,
+ RegionBranchPoint successor,
ArrayRef<AbstractSparseLattice *> lattices);
};
diff --git a/mlir/include/mlir/Interfaces/ControlFlowInterfaces.h b/mlir/include/mlir/Interfaces/ControlFlowInterfaces.h
index bd81da41ed43cd..006aedced839f9 100644
--- a/mlir/include/mlir/Interfaces/ControlFlowInterfaces.h
+++ b/mlir/include/mlir/Interfaces/ControlFlowInterfaces.h
@@ -190,6 +190,68 @@ class RegionSuccessor {
ValueRange inputs;
};
+/// This class represents a point being branched from in the methods of the
+/// `RegionBranchOpInterface`.
+/// One can branch from one of two kinds of places:
+/// * The parent operation (aka the `RegionBranchOpInterface` implementation)
+/// * A region within the parent operation.
+class RegionBranchPoint {
+public:
+ /// Returns an instance of `RegionBranchPoint` representing the parent
+ /// operation.
+ static constexpr RegionBranchPoint parent() { return RegionBranchPoint(); }
+
+ /// Creates a `RegionBranchPoint` that branches from the given region.
+ /// The pointer must not be null.
+ RegionBranchPoint(Region *region) : maybeRegion(region) {
+ assert(region && "Region must not be null");
+ }
+
+ RegionBranchPoint(Region &region) : RegionBranchPoint(&region) {}
+
+ /// Explicitly stops users from constructing with `nullptr`.
+ RegionBranchPoint(std::nullptr_t) = delete;
+
+ /// Constructs a `RegionBranchPoint` from the the target of a
+ /// `RegionSuccessor` instance.
+ RegionBranchPoint(RegionSuccessor successor) {
+ if (successor.isParent())
+ maybeRegion = nullptr;
+ else
+ maybeRegion = successor.getSuccessor();
+ }
+
+ /// Assigns a region being branched from.
+ RegionBranchPoint &operator=(Region &region) {
+ maybeRegion = &region;
+ return *this;
+ }
+
+ /// Returns true if branching from the parent op.
+ bool isParent() const { return maybeRegion == nullptr; }
+
+ /// Returns the region if branching from a region.
+ /// A null pointer otherwise.
+ Region *getRegionOrNull() const { return maybeRegion; }
+
+ /// Returns true if the two branch points are equal.
+ friend bool operator==(RegionBranchPoint lhs, RegionBranchPoint rhs) {
+ return lhs.maybeRegion == rhs.maybeRegion;
+ }
+
+private:
+ // Private constructor to encourage the use of `RegionBranchPoint::parent`.
+ constexpr RegionBranchPoint() : maybeRegion(nullptr) {}
+
+ /// Internal encoding. Uses nullptr for representing branching from the parent
+ /// op and the region being branched from otherwise.
+ Region *maybeRegion;
+};
+
+inline bool operator!=(RegionBranchPoint lhs, RegionBranchPoint rhs) {
+ return !(lhs == rhs);
+}
+
/// This class represents upper and lower bounds on the number of times a region
/// of a `RegionBranchOpInterface` can be invoked. The lower bound is at least
/// zero, but the upper bound may not be known.
diff --git a/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td b/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td
index 132bd6d53d923a..e52636a5ac8fcc 100644
--- a/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td
+++ b/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td
@@ -133,14 +133,14 @@ def RegionBranchOpInterface : OpInterface<"RegionBranchOpInterface"> {
let methods = [
InterfaceMethod<[{
Returns the operands of this operation used as the entry arguments when
- entering the region at `index`, which was specified as a successor of
+ branching from `point`, which was specified as a successor of
this operation by `getEntrySuccessorRegions`, or the operands forwarded
to the operation's results when it branches back to itself. These operands
should correspond 1-1 with the successor inputs specified in
`getEntrySuccessorRegions`.
}],
"::mlir::OperandRange", "getEntrySuccessorOperands",
- (ins "::std::optional<unsigned>":$index), [{}],
+ (ins "::mlir::RegionBranchPoint":$point), [{}],
/*defaultImplementation=*/[{
auto operandEnd = this->getOperation()->operand_end();
return ::mlir::OperandRange(operandEnd, operandEnd);
@@ -162,22 +162,20 @@ def RegionBranchOpInterface : OpInterface<"RegionBranchOpInterface"> {
(ins "::llvm::ArrayRef<::mlir::Attribute>":$operands,
"::llvm::SmallVectorImpl<::mlir::RegionSuccessor> &":$regions),
[{}], [{
- $_op.getSuccessorRegions(std::nullopt, regions);
+ $_op.getSuccessorRegions(mlir::RegionBranchPoint::parent(), regions);
}]
>,
InterfaceMethod<[{
- Returns the viable successors of a region at `index`, or the possible
- successors when branching from the parent op if `index` is None. These
- are the regions that may be selected during the flow of control. The
- parent operation, i.e. a null `index`, may specify itself as successor,
- which indicates that the control flow may not enter any region at all.
- This method allows for describing which regions may be executed when
- entering an operation, and which regions are executed after having
- executed another region of the parent op. The successor region must be
- non-empty.
+ Returns the viable successors of `point`. These are the regions that may
+ be selected during the flow of control. The parent operation, may
+ specify itself as successor, which indicates that the control flow may
+ not enter any region at all. This method allows for describing which
+ regions may be executed when entering an operation, and which regions
+ are executed after having executed another region of the parent op. The
+ successor region must be non-empty.
}],
"void", "getSuccessorRegions",
- (ins "::std::optional<unsigned>":$index,
+ (ins "::mlir::RegionBranchPoint":$point,
"::llvm::SmallVectorImpl<::mlir::RegionSuccessor> &":$regions)
>,
InterfaceMethod<[{
@@ -245,12 +243,10 @@ def RegionBranchTerminatorOpInterface :
let methods = [
InterfaceMethod<[{
Returns a mutable range of operands that are semantically "returned" by
- passing them to the region successor given by `index`. If `index` is
- None, this function returns the operands that are passed as a result to
- the parent operation.
+ passing them to the region successor given by `point`.
}],
"::mlir::MutableOperandRange", "getMutableSuccessorOperands",
- (ins "::std::optional<unsigned>":$index)
+ (ins "::mlir::RegionBranchPoint":$point)
>,
InterfaceMethod<[{
Returns the viable region successors that are branched to after this
@@ -269,8 +265,7 @@ def RegionBranchTerminatorOpInterface :
[{
::mlir::Operation *op = $_op;
::llvm::cast<::mlir::RegionBranchOpInterface>(op->getParentOp())
- .getSuccessorRegions(op->getParentRegion()->getRegionNumber(),
- regions);
+ .getSuccessorRegions(op->getParentRegion(), regions);
}]
>,
];
@@ -290,8 +285,8 @@ def RegionBranchTerminatorOpInterface :
// them to the region successor given by `index`. If `index` is None, this
// function returns the operands that are passed as a result to the parent
// operation.
- ::mlir::OperandRange getSuccessorOperands(std::optional<unsigned> index) {
- return getMutableSuccessorOperands(index);
+ ::mlir::OperandRange getSuccessorOperands(::mlir::RegionBranchPoint point) {
+ return getMutableSuccessorOperands(point);
}
}];
}
@@ -309,7 +304,7 @@ def ReturnLike : TraitList<[
/*extraOpDeclaration=*/"",
/*extraOpDefinition=*/[{
::mlir::MutableOperandRange $cppClass::getMutableSuccessorOperands(
- ::std::optional<unsigned> index) {
+ ::mlir::RegionBranchPoint point) {
return ::mlir::MutableOperandRange(*this);
}
}]
diff --git a/mlir/lib/Analysis/AliasAnalysis/LocalAliasAnalysis.cpp b/mlir/lib/Analysis/AliasAnalysis/LocalAliasAnalysis.cpp
index 970e68bc258649..ae2ba90412137c 100644
--- a/mlir/lib/Analysis/AliasAnalysis/LocalAliasAnalysis.cpp
+++ b/mlir/lib/Analysis/AliasAnalysis/LocalAliasAnalysis.cpp
@@ -45,9 +45,9 @@ static void collectUnderlyingAddressValues(RegionBranchOpInterface branch,
// this region predecessor that correspond to the input values of `region`. If
// an index could not be found, std::nullopt is returned instead.
auto getOperandIndexIfPred =
- [&](std::optional<unsigned> predIndex) -> std::optional<unsigned> {
+ [&](RegionBranchPoint pred) -> std::optional<unsigned> {
SmallVector<RegionSuccessor, 2> successors;
- branch.getSuccessorRegions(predIndex, successors);
+ branch.getSuccessorRegions(pred, successors);
for (RegionSuccessor &successor : successors) {
if (successor.getSuccessor() != region)
continue;
@@ -75,28 +75,27 @@ static void collectUnderlyingAddressValues(RegionBranchOpInterface branch,
};
// Check branches from the parent operation.
- std::optional<unsigned> regionIndex;
- if (region) {
- // Determine the actual region number from the passed region.
- regionIndex = region->getRegionNumber();
- }
+ auto branchPoint = RegionBranchPoint::parent();
+ if (region)
+ branchPoint = region;
+
if (std::optional<unsigned> operandIndex =
- getOperandIndexIfPred(/*predIndex=*/std::nullopt)) {
+ getOperandIndexIfPred(/*predIndex=*/RegionBranchPoint::parent())) {
collectUnderlyingAddressValues(
- branch.getEntrySuccessorOperands(regionIndex)[*operandIndex], maxDepth,
+ branch.getEntrySuccessorOperands(branchPoint)[*operandIndex], maxDepth,
visited, output);
}
// Check branches from each child region.
Operation *op = branch.getOperation();
- for (int i = 0, e = op->getNumRegions(); i != e; ++i) {
- if (std::optional<unsigned> operandIndex = getOperandIndexIfPred(i)) {
- for (Block &block : op->getRegion(i)) {
+ for (Region &region : op->getRegions()) {
+ if (std::optional<unsigned> operandIndex = getOperandIndexIfPred(region)) {
+ for (Block &block : region) {
// Try to determine possible region-branch successor operands for the
// current region.
if (auto term = dyn_cast<RegionBranchTerminatorOpInterface>(
block.getTerminator())) {
collectUnderlyingAddressValues(
- term.getSuccessorOperands(regionIndex)[*operandIndex], maxDepth,
+ term.getSuccessorOperands(branchPoint)[*operandIndex], maxDepth,
visited, output);
} else if (block.getNumSuccessors()) {
// Otherwise, if this terminator may exit the region we can't make
diff --git a/mlir/lib/Analysis/DataFlow/DenseAnalysis.cpp b/mlir/lib/Analysis/DataFlow/DenseAnalysis.cpp
index c79a360e4c11bb..eab408cd5977c3 100644
--- a/mlir/lib/Analysis/DataFlow/DenseAnalysis.cpp
+++ b/mlir/lib/Analysis/DataFlow/DenseAnalysis.cpp
@@ -312,7 +312,8 @@ void AbstractDenseBackwardDataFlowAnalysis::processOperation(Operation *op) {
// Special cases where control flow may dictate data flow.
if (auto branch = dyn_cast<RegionBranchOpInterface>(op))
- return visitRegionBranchOperation(op, branch, std::nullopt, before);
+ return visitRegionBranchOperation(op, branch, RegionBranchPoint::parent(),
+ before);
if (auto call = dyn_cast<CallOpInterface>(op))
return visitCallOperation(call, before);
@@ -368,8 +369,7 @@ void AbstractDenseBackwardDataFlowAnalysis::visitBlock(Block *block) {
// If this block is exiting from an operation with region-based control
// flow, propagate the lattice back along the control flow edge.
if (auto branch = dyn_cast<RegionBranchOpInterface>(block->getParentOp())) {
- visitRegionBranchOperation(block, branch,
- block->getParent()->getRegionNumber(), before);
+ visitRegionBranchOperation(block, branch, block->getParent(), before);
return;
}
@@ -396,13 +396,13 @@ void AbstractDenseBackwardDataFlowAnalysis::visitBlock(Block *block) {
void AbstractDenseBackwardDataFlowAnalysis::visitRegionBranchOperation(
ProgramPoint point, RegionBranchOpInterface branch,
- std::optional<unsigned> regionNo, AbstractDenseLattice *before) {
+ RegionBranchPoint branchPoint, AbstractDenseLattice *before) {
// The successors of the operation may be either the first operation of the
// entry block of each possible successor region, or the next operation when
// the branch is a successor of itself.
SmallVector<RegionSuccessor> successors;
- branch.getSuccessorRegions(regionNo, successors);
+ branch.getSuccessorRegions(branchPoint, successors);
for (const RegionSuccessor &successor : successors) {
const AbstractDenseLattice *after;
if (successor.isParent() || successor.getSuccessor()->empty()) {
@@ -423,10 +423,8 @@ void AbstractDenseBackwardDataFlowAnalysis::visitRegionBranchOperation(
else
after = getLatticeFor(point, &successorBlock->front());
}
- std::optional<unsigned> successorNo =
- successor.isParent() ? std::optional<unsigned>()
- : successor.getSuccessor()->getRegionNumber();
- visitRegionBranchControlFlowTransfer(branch, regionNo, successorNo, *after,
+
+ visitRegionBranchControlFlowTransfer(branch, branchPoint, successor, *after,
before);
}
}
diff --git a/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp b/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp
index 4708cdb042f126..02a0ce1bb29213 100644
--- a/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp
+++ b/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp
@@ -99,7 +99,7 @@ void AbstractSparseForwardDataFlowAnalysis::visitOperation(Operation *op) {
// The results of a region branch operation are determined by control-flow.
if (auto branch = dyn_cast<RegionBranchOpInterface>(op)) {
return visitRegionSuccessors({branch}, branch,
- /*successorIndex=*/std::nullopt,
+ /*successor=*/RegionBranchPoint::parent(),
resultLattices);
}
@@ -167,8 +167,8 @@ void AbstractSparseForwardDataFlowAnalysis::visitBlock(Block *block) {
// Check if the lattices can be determined from region control flow.
if (auto branch = dyn_cast<RegionBranchOpInterface>(block->getParentOp())) {
- return visitRegionSuccessors(
- block, branch, block->getParent()->getRegionNumber(), argLattices);
+ return visitRegionSuccessors(block, branch, block->getParent(),
+ argLattices);
}
// Otherwise, we can't reason about the data-flow.
@@ -212,8 +212,7 @@ void AbstractSparseForwardDataFlowAnalysis::visitBlock(Block *block) {
void AbstractSparseForwardDataFlowAnalysis::visitRegionSuccessors(
ProgramPoint point, RegionBranchOpInterface branch,
- std::optional<unsigned> successorIndex,
- ArrayRef<AbstractSparseLattice *> lattices) {
+ RegionBranchPoint successor, ArrayRef<AbstractSparseLattice *> lattices) {
const auto *predecessors = getOrCreateFor<PredecessorState>(point, point);
assert(predecessors->allPredecessorsKnown() &&
"unexpected unresolved region successors");
@@ -224,11 +223,11 @@ void AbstractSparseForwardDataFlowAnalysis::visitRegionSuccessors(
// Check if the predecessor is the parent op.
if (op == branch) {
- operands = branch.getEntrySuccessorOperands(successorIndex);
+ operands = branch.getEntrySuccessorOperands(successor);
// Otherwise, try to deduce the operands from a region return-like op.
} else if (auto regionTerminator =
dyn_cast<RegionBranchTerminatorOpInterface>(op)) {
- operands = regionTerminator.getSuccessorOperands(successorIndex);
+ operands = regionTerminator.getSuccessorOperands(successor);
}
if (!operands) {
@@ -501,10 +500,7 @@ void AbstractSparseBackwardDataFlowAnalysis::visitRegionSuccessors(
BitVector unaccounted(op->getNumOperands(), true);
for (RegionSuccessor &successor : successors) {
- Region *region = successor.getSuccessor();
- OperandRange operands =
- region ? branch.getEntrySuccessorOperands(region->getRegionNumber())
- : branch.getEntrySuccessorOperands({});
+ OperandRange operands = branch.getEntrySuccessorOperands(successor);
MutableArrayRef<OpOperand> opoperands = operandsToOpOperands(operands);
ValueRange inputs = successor.getSuccessorInputs();
for (auto [operand, input] : llvm::zip(opoperands, inputs)) {
@@ -538,9 +534,7 @@ void AbstractSparseBackwardDataFlowAnalysis::
for (const RegionSuccessor &successor : successors) {
ValueRange inputs = successor.getSuccessorInputs();
- Region *region = successor.getSuccessor();
- OperandRange operands = terminator.getSuccessorOperands(
- region ? region->getRegionNumber() : std::optional<unsigned>{});
+ OperandRange operands = terminator.getSuccessorOperands(successor);
MutableArrayRef<OpOperand> opOperands = operandsToOpOperands(operands);
for (auto [opOperand, input] : llvm::zip(opOperands, inputs)) {
meet(getLatticeElement(opOperand.get()),
diff --git a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
index bb4aaee21d019e..9d7b8f371a26c6 100644
--- a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
+++ b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
@@ -2379,9 +2379,9 @@ void AffineForOp::getCanonicalizationPatterns(RewritePatternSet &results,
/// correspond to the loop iterator operands, i.e., those excluding the
/// induction variable. AffineForOp only has one region, so zero is the only
/// valid value for `index`.
-OperandRange
-AffineForOp::getEntrySuccessorOperands(std::optional<unsigned> index) {
- assert((!index || *index == 0) && "invalid region index");
+OperandRange AffineForOp::getEntrySuccessorOperands(RegionBranchPoint point) {
+ assert((point.isParent() || point == getLoopBody()) &&
+ "invalid region point");
// The initial operands map to the loop arguments after the induction
// variable or are forwarded to the results when the trip count is zero.
@@ -2394,14 +2394,15 @@ AffineForOp::getEntrySuccessorOperands(std::optional<unsigned> index) {
/// correspond to a constant value for each operand, or null if that operand is
/// not a constant.
void AffineForOp::getSuccessorRegions(
- std::optional<unsigned> index, SmallVectorImpl<RegionSuccessor> &regions) {
- assert((!index.has_value() || index.value() == 0) && "expected loop region");
+ RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
+ assert((point.isParent() || point == getLoopBody()) &&
+ "expected loop region");
// The loop may typically branch back to its body or to the parent operation.
// If the predecessor is the parent op and the trip count is known to be at
// least one, branch into the body using the iterator arguments. And in cases
// we know the trip count is zero, it can only branch back to its parent.
std::optional<uint64_t> tripCount = getTrivialConstantTripCount(*this);
- if (!index.has_value() && tripCount.has_value()) {
+ if (point.isParent() && tripCount.has_value()) {
if (tripCount.value() > 0) {
regions.push_back(RegionSuccessor(&getLoopBody(), getRegionIterArgs()));
return;
@@ -2414,7 +2415,7 @@ void AffineForOp::getSuccessorRegions(
// From the loop body, if the trip count is one, we can only branch back to
// the parent.
- if (index && tripCount && *tripCount == 1) {
+ if (!point.isParent() && tripCount && *tripCount == 1) {
regions.push_back(RegionSuccessor(getResults()));
return;
}
@@ -2859,10 +2860,10 @@ struct AlwaysTrueOrFalseIf : public OpRewritePattern<AffineIfOp> {
/// AffineIfOp has two regions -- `then` and `else`. The flow of data should be
/// as follows: AffineIfOp -> `then`/`else` -> AffineIfOp
void AffineIfOp::getSuccessorRegions(
- std::optional<unsigned> index, SmallVectorImpl<RegionSuccessor> &regions) {
+ RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
// If the predecessor is an AffineIfOp, then branching into both `then` and
// `else` region is valid.
- if (!index.has_value()) {
+ if (point.isParent()) {
regions.reserve(2);
regions.push_back(
RegionSuccessor(&getThenRegion(), getThenRegion().getArguments()));
diff --git a/mlir/lib/Dialect/Async/IR/Async.cpp b/mlir/lib/Dialect/Async/IR/Async.cpp
index 9b4fb81990c169..a05e02faf6d2f0 100644
--- a/mlir/lib/Dialect/Async/IR/Async.cpp
+++ b/mlir/lib/Dialect/Async/IR/Async.cpp
@@ -38,9 +38,8 @@ void AsyncDialect::initialize() {
constexpr char kOperandSegmentSizesAttr[] = "operandSegmentSizes";
-OperandRange
-ExecuteOp::getEntrySuccessorOperands(std::optional<unsigned> index) {
- assert(index && *index == 0 && "invalid region index");
+OperandRange ExecuteOp::getEntrySuccessorOperands(RegionBranchPoint point) {
+ assert(point == getBodyRegion() && "invalid region index");
return getBodyOperands();
}
@@ -53,11 +52,10 @@ bool ExecuteOp::areTypesCompatible(Type lhs, Type rhs) {
return getValueOrTokenType(lhs) == getValueOrTokenType(rhs);
}
-void ExecuteOp::getSuccessorRegions(std::optional<unsigned> index,
+void ExecuteOp::getSuccessorRegions(RegionBranchPoint point,
SmallVectorImpl<RegionSuccessor> &regions) {
// The `body` region branch back to the parent operation.
- if (index) {
- assert(*index == 0 && "invalid region index");
+ if (point == getBodyRegion()) {
regions.push_back(RegionSuccessor(getBodyResults()));
return;
}
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocation.cpp b/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocation.cpp
index 582974873263d2..9a831e4c322ea0 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocation.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocation.cpp
@@ -372,7 +372,7 @@ class BufferDeallocation : public BufferPlacementTransformationBase {
// parent operation. In this case, we have to introduce an additional clone
// for buffer that is passed to the argument.
SmallVector<RegionSuccessor, 2> successorRegions;
- regionInterface.getSuccessorRegions(/*index=*/std::nullopt,
+ regionInterface.getSuccessorRegions(/*point=*/RegionBranchPoint::parent(),
successorRegions);
auto *it =
llvm::find_if(successorRegions, [&](RegionSuccessor &successorRegion) {
@@ -383,8 +383,7 @@ class BufferDeallocation : public BufferPlacementTransformationBase {
// Determine the actual operand to introduce a clone for and rewire the
// operand to point to the clone instead.
- auto operands =
- regionInterface.getEntrySuccessorOperands(argRegion->getRegionNumber());
+ auto operands = regionInterface.getEntrySuccessorOperands(argRegion);
size_t operandIndex =
llvm::find(it->getSuccessorInputs(), blockArg).getIndex() +
operands.getBeginOperandIndex();
@@ -432,8 +431,7 @@ class BufferDeallocation : public BufferPlacementTransformationBase {
// Query the regionInterface to get all successor regions of the current
// one.
SmallVector<RegionSuccessor, 2> successorRegions;
- regionInterface.getSuccessorRegions(region.getRegionNumber(),
- successorRegions);
+ regionInterface.getSuccessorRegions(region, successorRegions);
// Try to find a matching region successor.
RegionSuccessor *regionSuccessor =
llvm::find_if(successorRegions, regionPredicate);
@@ -445,10 +443,6 @@ class BufferDeallocation : public BufferPlacementTransformationBase {
llvm::find(regionSuccessor->getSuccessorInputs(), argValue)
.getIndex();
- std::optional<unsigned> successorRegionNumber;
- if (Region *successorRegion = regionSuccessor->getSuccessor())
- successorRegionNumber = successorRegion->getRegionNumber();
-
// Iterate over all immediate terminator operations to introduce
// new buffer allocations. Thereby, the appropriate terminator operand
// will be adjusted to point to the newly allocated buffer instead.
@@ -456,8 +450,7 @@ class BufferDeallocation : public BufferPlacementTransformationBase {
&region, [&](RegionBranchTerminatorOpInterface terminator) {
// Get the actual mutable operands for this terminator op.
auto terminatorOperands =
- terminator.getMutableSuccessorOperands(
- successorRegionNumber);
+ terminator.getMutableSuccessorOperands(*regionSuccessor);
// Extract the source value from the current terminator.
// This conversion needs to exist on a separate line due to a
// bug in GCC conversion analysis.
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/BufferUtils.cpp b/mlir/lib/Dialect/Bufferization/Transforms/BufferUtils.cpp
index f8231cac778af6..119801f9cc92f3 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/BufferUtils.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/BufferUtils.cpp
@@ -123,7 +123,7 @@ bool BufferPlacementTransformationBase::isLoop(Operation *op) {
return true;
// Recurses into all region successors.
SmallVector<RegionSuccessor, 2> successors;
- regionInterface.getSuccessorRegions(current->getRegionNumber(), successors);
+ regionInterface.getSuccessorRegions(current, successors);
for (RegionSuccessor &regionEntry : successors)
if (recurse(regionEntry.getSuccessor()))
return true;
@@ -132,7 +132,8 @@ bool BufferPlacementTransformationBase::isLoop(Operation *op) {
// Start with all entry regions and test whether they induce a loop.
SmallVector<RegionSuccessor, 2> successorRegions;
- regionInterface.getSuccessorRegions(/*index=*/std::nullopt, successorRegions);
+ regionInterface.getSuccessorRegions(/*point=*/RegionBranchPoint::parent(),
+ successorRegions);
for (RegionSuccessor &regionEntry : successorRegions) {
if (recurse(regionEntry.getSuccessor()))
return true;
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/BufferViewFlowAnalysis.cpp b/mlir/lib/Dialect/Bufferization/Transforms/BufferViewFlowAnalysis.cpp
index d201e024380661..98a60a48763ab1 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/BufferViewFlowAnalysis.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/BufferViewFlowAnalysis.cpp
@@ -100,16 +100,13 @@ void BufferViewFlowAnalysis::build(Operation *op) {
// Query the RegionBranchOpInterface to find potential successor regions.
// Extract all entry regions and wire all initial entry successor inputs.
SmallVector<RegionSuccessor, 2> entrySuccessors;
- regionInterface.getSuccessorRegions(/*index=*/std::nullopt,
+ regionInterface.getSuccessorRegions(/*point=*/RegionBranchPoint::parent(),
entrySuccessors);
for (RegionSuccessor &entrySuccessor : entrySuccessors) {
// Wire the entry region's successor arguments with the initial
// successor inputs.
registerDependencies(
- regionInterface.getEntrySuccessorOperands(
- entrySuccessor.isParent()
- ? std::optional<unsigned>()
- : entrySuccessor.getSuccessor()->getRegionNumber()),
+ regionInterface.getEntrySuccessorOperands(entrySuccessor),
entrySuccessor.getSuccessorInputs());
}
@@ -118,21 +115,16 @@ void BufferViewFlowAnalysis::build(Operation *op) {
// Iterate over all successor region entries that are reachable from the
// current region.
SmallVector<RegionSuccessor, 2> successorRegions;
- regionInterface.getSuccessorRegions(region.getRegionNumber(),
- successorRegions);
+ regionInterface.getSuccessorRegions(region, successorRegions);
for (RegionSuccessor &successorRegion : successorRegions) {
- // Determine the current region index (if any).
- std::optional<unsigned> regionIndex;
- Region *regionSuccessor = successorRegion.getSuccessor();
- if (regionSuccessor)
- regionIndex = regionSuccessor->getRegionNumber();
// Iterate over all immediate terminator operations and wire the
// successor inputs with the successor operands of each terminator.
for (Block &block : region)
if (auto terminator = dyn_cast<RegionBranchTerminatorOpInterface>(
block.getTerminator()))
- registerDependencies(terminator.getSuccessorOperands(regionIndex),
- successorRegion.getSuccessorInputs());
+ registerDependencies(
+ terminator.getSuccessorOperands(successorRegion),
+ successorRegion.getSuccessorInputs());
}
}
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index e1b8dd62450a77..9c5c322e23692b 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -455,8 +455,8 @@ ParseResult AllocaScopeOp::parse(OpAsmParser &parser, OperationState &result) {
}
void AllocaScopeOp::getSuccessorRegions(
- std::optional<unsigned> index, SmallVectorImpl<RegionSuccessor> &regions) {
- if (index) {
+ RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
+ if (!point.isParent()) {
regions.push_back(RegionSuccessor(getResults()));
return;
}
diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
index 63ce3b2a469627..b573291f0460e6 100644
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -266,9 +266,9 @@ void ExecuteRegionOp::getCanonicalizationPatterns(RewritePatternSet &results,
/// correspond to a constant value for each operand, or null if that operand is
/// not a constant.
void ExecuteRegionOp::getSuccessorRegions(
- std::optional<unsigned> index, SmallVectorImpl<RegionSuccessor> &regions) {
+ RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
// If the predecessor is the ExecuteRegionOp, branch into the body.
- if (!index) {
+ if (point.isParent()) {
regions.push_back(RegionSuccessor(&getRegion()));
return;
}
@@ -282,8 +282,8 @@ void ExecuteRegionOp::getSuccessorRegions(
//===----------------------------------------------------------------------===//
MutableOperandRange
-ConditionOp::getMutableSuccessorOperands(std::optional<unsigned> index) {
- assert((!index || index == getParentOp().getAfter().getRegionNumber()) &&
+ConditionOp::getMutableSuccessorOperands(RegionBranchPoint point) {
+ assert((point.isParent() || point == getParentOp().getAfter()) &&
"condition op can only exit the loop or branch to the after"
"region");
// Pass all operands except the condition to the successor region.
@@ -553,7 +553,7 @@ ForOp mlir::scf::getForInductionVarOwner(Value val) {
/// Return operands used when entering the region at 'index'. These operands
/// correspond to the loop iterator operands, i.e., those excluding the
/// induction variable.
-OperandRange ForOp::getEntrySuccessorOperands(std::optional<unsigned> index) {
+OperandRange ForOp::getEntrySuccessorOperands(RegionBranchPoint point) {
return getInitArgs();
}
@@ -562,7 +562,7 @@ OperandRange ForOp::getEntrySuccessorOperands(std::optional<unsigned> index) {
/// during the flow of control. `operands` is a set of optional attributes that
/// correspond to a constant value for each operand, or null if that operand is
/// not a constant.
-void ForOp::getSuccessorRegions(std::optional<unsigned> index,
+void ForOp::getSuccessorRegions(RegionBranchPoint point,
SmallVectorImpl<RegionSuccessor> &regions) {
// Both the operation itself and the region may be branching into the body or
// back into the operation itself. It is possible for loop not to enter the
@@ -1731,7 +1731,7 @@ void ForallOp::getCanonicalizationPatterns(RewritePatternSet &results,
/// during the flow of control. `operands` is a set of optional attributes that
/// correspond to a constant value for each operand, or null if that operand is
/// not a constant.
-void ForallOp::getSuccessorRegions(std::optional<unsigned> index,
+void ForallOp::getSuccessorRegions(RegionBranchPoint point,
SmallVectorImpl<RegionSuccessor> &regions) {
// Both the operation itself and the region may be branching into the body or
// back into the operation itself. It is possible for loop not to enter the
@@ -2011,10 +2011,10 @@ void IfOp::print(OpAsmPrinter &p) {
/// during the flow of control. `operands` is a set of optional attributes that
/// correspond to a constant value for each operand, or null if that operand is
/// not a constant.
-void IfOp::getSuccessorRegions(std::optional<unsigned> index,
+void IfOp::getSuccessorRegions(RegionBranchPoint point,
SmallVectorImpl<RegionSuccessor> &regions) {
// The `then` and the `else` region branch back to the parent operation.
- if (index) {
+ if (!point.isParent()) {
regions.push_back(RegionSuccessor(getResults()));
return;
}
@@ -3042,7 +3042,7 @@ void ParallelOp::getCanonicalizationPatterns(RewritePatternSet &results,
/// correspond to a constant value for each operand, or null if that operand is
/// not a constant.
void ParallelOp::getSuccessorRegions(
- std::optional<unsigned> index, SmallVectorImpl<RegionSuccessor> &regions) {
+ RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
// Both the operation itself and the region may be branching into the body or
// back into the operation itself. It is possible for loop not to enter the
// body.
@@ -3169,8 +3169,8 @@ void WhileOp::build(::mlir::OpBuilder &odsBuilder,
afterBuilder(odsBuilder, odsState.location, afterBlock->getArguments());
}
-OperandRange WhileOp::getEntrySuccessorOperands(std::optional<unsigned> index) {
- assert(index && *index == 0 &&
+OperandRange WhileOp::getEntrySuccessorOperands(RegionBranchPoint point) {
+ assert(point == getBefore() &&
"WhileOp is expected to branch only to the first region");
return getInits();
@@ -3192,17 +3192,18 @@ Block::BlockArgListType WhileOp::getAfterArguments() {
return getAfterBody()->getArguments();
}
-void WhileOp::getSuccessorRegions(std::optional<unsigned> index,
+void WhileOp::getSuccessorRegions(RegionBranchPoint point,
SmallVectorImpl<RegionSuccessor> &regions) {
// The parent op always branches to the condition region.
- if (!index) {
+ if (point.isParent()) {
regions.emplace_back(&getBefore(), getBefore().getArguments());
return;
}
- assert(*index < 2 && "there are only two regions in a WhileOp");
+ assert(llvm::is_contained({&getAfter(), &getBefore()}, point) &&
+ "there are only two regions in a WhileOp");
// The body region always branches back to the condition region.
- if (*index == 1) {
+ if (point == getAfter()) {
regions.emplace_back(&getBefore(), getBefore().getArguments());
return;
}
@@ -4023,10 +4024,9 @@ Block &scf::IndexSwitchOp::getCaseBlock(unsigned idx) {
}
void IndexSwitchOp::getSuccessorRegions(
- std::optional<unsigned> index,
- SmallVectorImpl<RegionSuccessor> &successors) {
+ RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &successors) {
// All regions branch back to the parent op.
- if (index) {
+ if (!point.isParent()) {
successors.emplace_back(getResults());
return;
}
diff --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp
index c52a9c5004aaaf..78b06d9ce033f8 100644
--- a/mlir/lib/Dialect/Shape/IR/Shape.cpp
+++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp
@@ -335,11 +335,11 @@ void AssumingOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
// See RegionBranchOpInterface in Interfaces/ControlFlowInterfaces.td
void AssumingOp::getSuccessorRegions(
- std::optional<unsigned> index, SmallVectorImpl<RegionSuccessor> &regions) {
+ RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
// AssumingOp has unconditional control flow into the region and back to the
// parent, so return the correct RegionSuccessor purely based on the index
// being None or 0.
- if (index) {
+ if (!point.isParent()) {
regions.push_back(RegionSuccessor(getResults()));
return;
}
diff --git a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
index 7bc7272b054129..518bfc3931e8d3 100644
--- a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
@@ -86,23 +86,25 @@ ensurePayloadIsSeparateFromTransform(transform::TransformOpInterface transform,
// AlternativesOp
//===----------------------------------------------------------------------===//
-OperandRange transform::AlternativesOp::getEntrySuccessorOperands(
- std::optional<unsigned> index) {
- if (index && getOperation()->getNumOperands() == 1)
+OperandRange
+transform::AlternativesOp::getEntrySuccessorOperands(RegionBranchPoint point) {
+ if (!point.isParent() && getOperation()->getNumOperands() == 1)
return getOperation()->getOperands();
return OperandRange(getOperation()->operand_end(),
getOperation()->operand_end());
}
void transform::AlternativesOp::getSuccessorRegions(
- std::optional<unsigned> index, SmallVectorImpl<RegionSuccessor> &regions) {
+ RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
for (Region &alternative : llvm::drop_begin(
- getAlternatives(), index.has_value() ? *index + 1 : 0)) {
+ getAlternatives(),
+ point.isParent() ? 0
+ : point.getRegionOrNull()->getRegionNumber() + 1)) {
regions.emplace_back(&alternative, !getOperands().empty()
? alternative.getArguments()
: Block::BlockArgListType());
}
- if (index.has_value())
+ if (!point.isParent())
regions.emplace_back(getOperation()->getResults());
}
@@ -1159,24 +1161,24 @@ void transform::ForeachOp::getEffects(
}
void transform::ForeachOp::getSuccessorRegions(
- std::optional<unsigned> index, SmallVectorImpl<RegionSuccessor> &regions) {
+ RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
Region *bodyRegion = &getBody();
- if (!index) {
+ if (point.isParent()) {
regions.emplace_back(bodyRegion, bodyRegion->getArguments());
return;
}
// Branch back to the region or the parent.
- assert(*index == 0 && "unexpected region index");
+ assert(point == getBody() && "unexpected region index");
regions.emplace_back(bodyRegion, bodyRegion->getArguments());
regions.emplace_back();
}
OperandRange
-transform::ForeachOp::getEntrySuccessorOperands(std::optional<unsigned> index) {
+transform::ForeachOp::getEntrySuccessorOperands(RegionBranchPoint point) {
// The iteration variable op handle is mapped to a subset (one op to be
// precise) of the payload ops of the ForeachOp operand.
- assert(index && *index == 0 && "unexpected region index");
+ assert(point == getBody() && "unexpected region index");
return getOperation()->getOperands();
}
@@ -2178,9 +2180,9 @@ void transform::SequenceOp::getEffects(
getPotentialTopLevelEffects(effects);
}
-OperandRange transform::SequenceOp::getEntrySuccessorOperands(
- std::optional<unsigned> index) {
- assert(index && *index == 0 && "unexpected region index");
+OperandRange
+transform::SequenceOp::getEntrySuccessorOperands(RegionBranchPoint point) {
+ assert(point == getBody() && "unexpected region index");
if (getOperation()->getNumOperands() > 0)
return getOperation()->getOperands();
return OperandRange(getOperation()->operand_end(),
@@ -2188,8 +2190,8 @@ OperandRange transform::SequenceOp::getEntrySuccessorOperands(
}
void transform::SequenceOp::getSuccessorRegions(
- std::optional<unsigned> index, SmallVectorImpl<RegionSuccessor> &regions) {
- if (!index) {
+ RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
+ if (point.isParent()) {
Region *bodyRegion = &getBody();
regions.emplace_back(bodyRegion, getNumOperands() != 0
? bodyRegion->getArguments()
@@ -2197,7 +2199,7 @@ void transform::SequenceOp::getSuccessorRegions(
return;
}
- assert(*index == 0 && "unexpected region index");
+ assert(point == getBody() && "unexpected region index");
regions.emplace_back(getOperation()->getResults());
}
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 4e9364611b257d..88bda3931a5a11 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -5821,8 +5821,8 @@ ParseResult WarpExecuteOnLane0Op::parse(OpAsmParser &parser,
}
void WarpExecuteOnLane0Op::getSuccessorRegions(
- std::optional<unsigned> index, SmallVectorImpl<RegionSuccessor> &regions) {
- if (index) {
+ RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
+ if (!point.isParent()) {
regions.push_back(RegionSuccessor(getResults()));
return;
}
diff --git a/mlir/lib/Interfaces/ControlFlowInterfaces.cpp b/mlir/lib/Interfaces/ControlFlowInterfaces.cpp
index cc90da370de693..b3166155e84f93 100644
--- a/mlir/lib/Interfaces/ControlFlowInterfaces.cpp
+++ b/mlir/lib/Interfaces/ControlFlowInterfaces.cpp
@@ -84,18 +84,18 @@ detail::verifyBranchSuccessorOperands(Operation *op, unsigned succNo,
// RegionBranchOpInterface
//===----------------------------------------------------------------------===//
-static InFlightDiagnostic &
-printRegionEdgeName(InFlightDiagnostic &diag, std::optional<unsigned> sourceNo,
- std::optional<unsigned> succRegionNo) {
+static InFlightDiagnostic &printRegionEdgeName(InFlightDiagnostic &diag,
+ RegionBranchPoint sourceNo,
+ RegionBranchPoint succRegionNo) {
diag << "from ";
- if (sourceNo)
- diag << "Region #" << sourceNo.value();
+ if (Region *region = sourceNo.getRegionOrNull())
+ diag << "Region #" << region->getRegionNumber();
else
diag << "parent operands";
diag << " to ";
- if (succRegionNo)
- diag << "Region #" << succRegionNo.value();
+ if (Region *region = succRegionNo.getRegionOrNull())
+ diag << "Region #" << region->getRegionNumber();
else
diag << "parent results";
return diag;
@@ -107,28 +107,24 @@ printRegionEdgeName(InFlightDiagnostic &diag, std::optional<unsigned> sourceNo,
/// inputs that flow from `sourceIndex' to the given region, or std::nullopt if
/// the exact type match verification is not necessary (e.g., if the Op verifies
/// the match itself).
-static LogicalResult verifyTypesAlongAllEdges(
- Operation *op, std::optional<unsigned> sourceNo,
- function_ref<FailureOr<TypeRange>(std::optional<unsigned>)>
- getInputsTypesForRegion) {
+static LogicalResult
+verifyTypesAlongAllEdges(Operation *op, RegionBranchPoint sourcePoint,
+ function_ref<FailureOr<TypeRange>(RegionBranchPoint)>
+ getInputsTypesForRegion) {
auto regionInterface = cast<RegionBranchOpInterface>(op);
SmallVector<RegionSuccessor, 2> successors;
- regionInterface.getSuccessorRegions(sourceNo, successors);
+ regionInterface.getSuccessorRegions(sourcePoint, successors);
for (RegionSuccessor &succ : successors) {
- std::optional<unsigned> succRegionNo;
- if (!succ.isParent())
- succRegionNo = succ.getSuccessor()->getRegionNumber();
-
- FailureOr<TypeRange> sourceTypes = getInputsTypesForRegion(succRegionNo);
+ FailureOr<TypeRange> sourceTypes = getInputsTypesForRegion(succ);
if (failed(sourceTypes))
return failure();
TypeRange succInputsTypes = succ.getSuccessorInputs().getTypes();
if (sourceTypes->size() != succInputsTypes.size()) {
InFlightDiagnostic diag = op->emitOpError(" region control flow edge ");
- return printRegionEdgeName(diag, sourceNo, succRegionNo)
+ return printRegionEdgeName(diag, sourcePoint, succ)
<< ": source has " << sourceTypes->size()
<< " operands, but target successor needs "
<< succInputsTypes.size();
@@ -140,7 +136,7 @@ static LogicalResult verifyTypesAlongAllEdges(
Type inputType = std::get<1>(typesIdx.value());
if (!regionInterface.areTypesCompatible(sourceType, inputType)) {
InFlightDiagnostic diag = op->emitOpError(" along control flow edge ");
- return printRegionEdgeName(diag, sourceNo, succRegionNo)
+ return printRegionEdgeName(diag, sourcePoint, succ)
<< ": source type #" << typesIdx.index() << " " << sourceType
<< " should match input type #" << typesIdx.index() << " "
<< inputType;
@@ -154,13 +150,13 @@ static LogicalResult verifyTypesAlongAllEdges(
LogicalResult detail::verifyTypesAlongControlFlowEdges(Operation *op) {
auto regionInterface = cast<RegionBranchOpInterface>(op);
- auto inputTypesFromParent =
- [&](std::optional<unsigned> regionNo) -> TypeRange {
+ auto inputTypesFromParent = [&](RegionBranchPoint regionNo) -> TypeRange {
return regionInterface.getEntrySuccessorOperands(regionNo).getTypes();
};
// Verify types along control flow edges originating from the parent.
- if (failed(verifyTypesAlongAllEdges(op, std::nullopt, inputTypesFromParent)))
+ if (failed(verifyTypesAlongAllEdges(op, RegionBranchPoint::parent(),
+ inputTypesFromParent)))
return failure();
auto areTypesCompatible = [&](TypeRange lhs, TypeRange rhs) {
@@ -176,8 +172,7 @@ LogicalResult detail::verifyTypesAlongControlFlowEdges(Operation *op) {
};
// Verify types along control flow edges originating from each region.
- for (unsigned regionNo : llvm::seq(0U, op->getNumRegions())) {
- Region &region = op->getRegion(regionNo);
+ for (Region &region : op->getRegions()) {
// Since there can be multiple terminators implementing the
// `RegionBranchTerminatorOpInterface`, all should have the same operand
@@ -195,7 +190,7 @@ LogicalResult detail::verifyTypesAlongControlFlowEdges(Operation *op) {
continue;
auto inputTypesForRegion =
- [&](std::optional<unsigned> succRegionNo) -> FailureOr<TypeRange> {
+ [&](RegionBranchPoint succRegionNo) -> FailureOr<TypeRange> {
std::optional<OperandRange> regionReturnOperands;
for (RegionBranchTerminatorOpInterface regionReturnOp : regionReturnOps) {
auto terminatorOperands =
@@ -211,7 +206,7 @@ LogicalResult detail::verifyTypesAlongControlFlowEdges(Operation *op) {
if (!areTypesCompatible(regionReturnOperands->getTypes(),
terminatorOperands.getTypes())) {
InFlightDiagnostic diag = op->emitOpError(" along control flow edge");
- return printRegionEdgeName(diag, regionNo, succRegionNo)
+ return printRegionEdgeName(diag, region, succRegionNo)
<< " operands mismatch between return-like terminators";
}
}
@@ -220,7 +215,7 @@ LogicalResult detail::verifyTypesAlongControlFlowEdges(Operation *op) {
return TypeRange(regionReturnOperands->getTypes());
};
- if (failed(verifyTypesAlongAllEdges(op, regionNo, inputTypesForRegion)))
+ if (failed(verifyTypesAlongAllEdges(op, region, inputTypesForRegion)))
return failure();
}
@@ -237,24 +232,24 @@ static bool isRegionReachable(Region *begin, Region *r) {
visited[begin->getRegionNumber()] = true;
// Retrieve all successors of the region and enqueue them in the worklist.
- SmallVector<unsigned> worklist;
- auto enqueueAllSuccessors = [&](unsigned index) {
+ SmallVector<Region *> worklist;
+ auto enqueueAllSuccessors = [&](Region *region) {
SmallVector<RegionSuccessor> successors;
- op.getSuccessorRegions(index, successors);
+ op.getSuccessorRegions(region, successors);
for (RegionSuccessor successor : successors)
if (!successor.isParent())
- worklist.push_back(successor.getSuccessor()->getRegionNumber());
+ worklist.push_back(successor.getSuccessor());
};
- enqueueAllSuccessors(begin->getRegionNumber());
+ enqueueAllSuccessors(begin);
// Process all regions in the worklist via DFS.
while (!worklist.empty()) {
- unsigned nextRegion = worklist.pop_back_val();
- if (nextRegion == r->getRegionNumber())
+ Region *nextRegion = worklist.pop_back_val();
+ if (nextRegion == r)
return true;
- if (visited[nextRegion])
+ if (visited[nextRegion->getRegionNumber()])
continue;
- visited[nextRegion] = true;
+ visited[nextRegion->getRegionNumber()] = true;
enqueueAllSuccessors(nextRegion);
}
diff --git a/mlir/lib/Transforms/RemoveDeadValues.cpp b/mlir/lib/Transforms/RemoveDeadValues.cpp
index ce19dc667f009d..19a84db34dce02 100644
--- a/mlir/lib/Transforms/RemoveDeadValues.cpp
+++ b/mlir/lib/Transforms/RemoveDeadValues.cpp
@@ -316,15 +316,11 @@ static void cleanRegionBranchOp(RegionBranchOpInterface regionBranchOp,
// Return the successors of `region` if the latter is not null. Else return
// the successors of `regionBranchOp`.
auto getSuccessors = [&](Region *region = nullptr) {
- std::optional<unsigned> index =
- region ? std::optional(region->getRegionNumber()) : std::nullopt;
+ auto point = region ? region : RegionBranchPoint::parent();
SmallVector<Attribute> operandAttributes(regionBranchOp->getNumOperands(),
nullptr);
SmallVector<RegionSuccessor> successors;
- if (!index)
- regionBranchOp.getEntrySuccessorRegions(operandAttributes, successors);
- else
- regionBranchOp.getSuccessorRegions(index, successors);
+ regionBranchOp.getSuccessorRegions(point, successors);
return successors;
};
@@ -333,14 +329,10 @@ static void cleanRegionBranchOp(RegionBranchOpInterface regionBranchOp,
// forwarded to `successor`.
auto getForwardedOpOperands = [&](const RegionSuccessor &successor,
Operation *terminator = nullptr) {
- Region *successorRegion = successor.getSuccessor();
- std::optional<unsigned> index =
- successorRegion ? std::optional(successorRegion->getRegionNumber())
- : std::nullopt;
OperandRange operands =
terminator ? cast<RegionBranchTerminatorOpInterface>(terminator)
- .getSuccessorOperands(index)
- : regionBranchOp.getEntrySuccessorOperands(index);
+ .getSuccessorOperands(successor)
+ : regionBranchOp.getEntrySuccessorOperands(successor);
SmallVector<OpOperand *> opOperands = operandsToOpOperands(operands);
return opOperands;
};
diff --git a/mlir/test/lib/Analysis/DataFlow/TestDenseBackwardDataFlowAnalysis.cpp b/mlir/test/lib/Analysis/DataFlow/TestDenseBackwardDataFlowAnalysis.cpp
index a33b523d5d192f..8bfd01d828060a 100644
--- a/mlir/test/lib/Analysis/DataFlow/TestDenseBackwardDataFlowAnalysis.cpp
+++ b/mlir/test/lib/Analysis/DataFlow/TestDenseBackwardDataFlowAnalysis.cpp
@@ -60,8 +60,8 @@ class NextAccessAnalysis : public DenseBackwardDataFlowAnalysis<NextAccess> {
NextAccess *before) override;
void visitRegionBranchControlFlowTransfer(RegionBranchOpInterface branch,
- std::optional<unsigned> regionFrom,
- std::optional<unsigned> regionTo,
+ RegionBranchPoint regionFrom,
+ RegionBranchPoint regionTo,
const NextAccess &after,
NextAccess *before) override;
@@ -124,15 +124,15 @@ void NextAccessAnalysis::visitCallControlFlowTransfer(
}
void NextAccessAnalysis::visitRegionBranchControlFlowTransfer(
- RegionBranchOpInterface branch, std::optional<unsigned> regionFrom,
- std::optional<unsigned> regionTo, const NextAccess &after,
- NextAccess *before) {
+ RegionBranchOpInterface branch, RegionBranchPoint regionFrom,
+ RegionBranchPoint regionTo, const NextAccess &after, NextAccess *before) {
auto testStoreWithARegion =
dyn_cast<::test::TestStoreWithARegion>(branch.getOperation());
if (testStoreWithARegion &&
- ((!regionTo && !testStoreWithARegion.getStoreBeforeRegion()) ||
- (!regionFrom && testStoreWithARegion.getStoreBeforeRegion()))) {
+ ((regionTo.isParent() && !testStoreWithARegion.getStoreBeforeRegion()) ||
+ (regionFrom.isParent() &&
+ testStoreWithARegion.getStoreBeforeRegion()))) {
visitOperation(branch, static_cast<const NextAccess &>(after),
static_cast<NextAccess *>(before));
} else {
@@ -219,7 +219,7 @@ struct TestNextAccessPass
SmallVector<Attribute> entryPointNextAccess;
SmallVector<RegionSuccessor> regionSuccessors;
- iface.getSuccessorRegions(std::nullopt, regionSuccessors);
+ iface.getSuccessorRegions(RegionBranchPoint::parent(), regionSuccessors);
for (const RegionSuccessor &successor : regionSuccessors) {
if (!successor.getSuccessor() || successor.getSuccessor()->empty())
continue;
diff --git a/mlir/test/lib/Dialect/Test/TestDialect.cpp b/mlir/test/lib/Dialect/Test/TestDialect.cpp
index 57a6ab387281dc..34ed7a1a66fe33 100644
--- a/mlir/test/lib/Dialect/Test/TestDialect.cpp
+++ b/mlir/test/lib/Dialect/Test/TestDialect.cpp
@@ -931,17 +931,17 @@ ParseResult RegionIfOp::parse(OpAsmParser &parser, OperationState &result) {
parser.getCurrentLocation(), result.operands);
}
-OperandRange
-RegionIfOp::getEntrySuccessorOperands(std::optional<unsigned> index) {
- assert(index && *index < 2 && "invalid region index");
+OperandRange RegionIfOp::getEntrySuccessorOperands(RegionBranchPoint point) {
+ assert(llvm::is_contained({&getThenRegion(), &getElseRegion()}, point) &&
+ "invalid region index");
return getOperands();
}
void RegionIfOp::getSuccessorRegions(
- std::optional<unsigned> index, SmallVectorImpl<RegionSuccessor> &regions) {
+ RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
// We always branch to the join region.
- if (index.has_value()) {
- if (index.value() < 2)
+ if (!point.isParent()) {
+ if (point != getJoinRegion())
regions.push_back(RegionSuccessor(&getJoinRegion(), getJoinArgs()));
else
regions.push_back(RegionSuccessor(getResults()));
@@ -964,11 +964,11 @@ void RegionIfOp::getRegionInvocationBounds(
// AnyCondOp
//===----------------------------------------------------------------------===//
-void AnyCondOp::getSuccessorRegions(std::optional<unsigned> index,
+void AnyCondOp::getSuccessorRegions(RegionBranchPoint point,
SmallVectorImpl<RegionSuccessor> &regions) {
// The parent op branches into the only region, and the region branches back
// to the parent op.
- if (!index)
+ if (point.isParent())
regions.emplace_back(&getRegion());
else
regions.emplace_back(getResults());
@@ -985,17 +985,16 @@ void AnyCondOp::getRegionInvocationBounds(
//===----------------------------------------------------------------------===//
void LoopBlockOp::getSuccessorRegions(
- std::optional<unsigned> index, SmallVectorImpl<RegionSuccessor> &regions) {
+ RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
regions.emplace_back(&getBody(), getBody().getArguments());
- if (!index)
+ if (point.isParent())
return;
regions.emplace_back((*this)->getResults());
}
-OperandRange
-LoopBlockOp::getEntrySuccessorOperands(std::optional<unsigned> index) {
- assert(index == 0);
+OperandRange LoopBlockOp::getEntrySuccessorOperands(RegionBranchPoint point) {
+ assert(point == getBody());
return getInitMutable();
}
@@ -1003,10 +1002,9 @@ LoopBlockOp::getEntrySuccessorOperands(std::optional<unsigned> index) {
// LoopBlockTerminatorOp
//===----------------------------------------------------------------------===//
-MutableOperandRange LoopBlockTerminatorOp::getMutableSuccessorOperands(
- std::optional<unsigned> index) {
- assert(!index || index == 0);
- if (!index)
+MutableOperandRange
+LoopBlockTerminatorOp::getMutableSuccessorOperands(RegionBranchPoint point) {
+ if (point.isParent())
return getExitArgMutable();
return getNextIterArgMutable();
}
@@ -1313,12 +1311,11 @@ MutableOperandRange TestCallOnDeviceOp::getArgOperandsMutable() {
}
void TestStoreWithARegion::getSuccessorRegions(
- std::optional<unsigned> index, SmallVectorImpl<RegionSuccessor> &regions) {
- if (!index) {
+ RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
+ if (point.isParent())
regions.emplace_back(&getBody(), getBody().front().getArguments());
- } else {
+ else
regions.emplace_back();
- }
}
LogicalResult
diff --git a/mlir/unittests/Interfaces/ControlFlowInterfacesTest.cpp b/mlir/unittests/Interfaces/ControlFlowInterfacesTest.cpp
index a507baa6445d97..f1aae15393fd3f 100644
--- a/mlir/unittests/Interfaces/ControlFlowInterfacesTest.cpp
+++ b/mlir/unittests/Interfaces/ControlFlowInterfacesTest.cpp
@@ -37,7 +37,7 @@ struct MutuallyExclusiveRegionsOp
}
// Regions have no successors.
- void getSuccessorRegions(std::optional<unsigned> index,
+ void getSuccessorRegions(RegionBranchPoint point,
SmallVectorImpl<RegionSuccessor> &regions) {}
};
@@ -51,14 +51,13 @@ struct LoopRegionsOp
static StringRef getOperationName() { return "cftest.loop_regions_op"; }
- void getSuccessorRegions(std::optional<unsigned> index,
+ void getSuccessorRegions(RegionBranchPoint point,
SmallVectorImpl<RegionSuccessor> &regions) {
- if (index) {
- if (*index == 1)
+ if (Region *region = point.getRegionOrNull()) {
+ if (point == (*this)->getRegion(1))
// This region also branches back to the parent.
regions.push_back(RegionSuccessor());
- regions.push_back(
- RegionSuccessor(&getOperation()->getRegion(*index % kNumRegions)));
+ regions.push_back(RegionSuccessor(region));
}
}
};
@@ -74,11 +73,11 @@ struct DoubleLoopRegionsOp
return "cftest.double_loop_regions_op";
}
- void getSuccessorRegions(std::optional<unsigned> index,
+ void getSuccessorRegions(RegionBranchPoint point,
SmallVectorImpl<RegionSuccessor> &regions) {
- if (index.has_value()) {
+ if (Region *region = point.getRegionOrNull()) {
regions.push_back(RegionSuccessor());
- regions.push_back(RegionSuccessor(&getOperation()->getRegion(*index)));
+ regions.push_back(RegionSuccessor(region));
}
}
};
@@ -92,9 +91,9 @@ struct SequentialRegionsOp
static StringRef getOperationName() { return "cftest.sequential_regions_op"; }
// Region 0 has Region 1 as a successor.
- void getSuccessorRegions(std::optional<unsigned> index,
+ void getSuccessorRegions(RegionBranchPoint point,
SmallVectorImpl<RegionSuccessor> &regions) {
- if (index == 0u) {
+ if (point == (*this)->getRegion(0)) {
Operation *thisOp = this->getOperation();
regions.push_back(RegionSuccessor(&thisOp->getRegion(1)));
}
More information about the Mlir-commits
mailing list