[Mlir-commits] [mlir] b26bb30 - Revert "[mlir] Use a type for representing branch points in `RegionBranchOpInterface`"
Markus Böck
llvmlistbot at llvm.org
Tue Aug 29 11:46:12 PDT 2023
Author: Markus Böck
Date: 2023-08-29T20:17:50+02:00
New Revision: b26bb30b467b996c9786e3bd426c07684d84d406
URL: https://github.com/llvm/llvm-project/commit/b26bb30b467b996c9786e3bd426c07684d84d406
DIFF: https://github.com/llvm/llvm-project/commit/b26bb30b467b996c9786e3bd426c07684d84d406.diff
LOG: Revert "[mlir] Use a type for representing branch points in `RegionBranchOpInterface`"
This reverts commit 024f562da67180b7be1663048c960b26c2cc16f8.
Forgot to update flang
Added:
Modified:
mlir/include/mlir/Analysis/DataFlow/DenseAnalysis.h
mlir/include/mlir/Analysis/DataFlow/SparseAnalysis.h
mlir/include/mlir/Interfaces/ControlFlowInterfaces.h
mlir/include/mlir/Interfaces/ControlFlowInterfaces.td
mlir/lib/Analysis/AliasAnalysis/LocalAliasAnalysis.cpp
mlir/lib/Analysis/DataFlow/DenseAnalysis.cpp
mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp
mlir/lib/Dialect/Affine/IR/AffineOps.cpp
mlir/lib/Dialect/Async/IR/Async.cpp
mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocation.cpp
mlir/lib/Dialect/Bufferization/Transforms/BufferUtils.cpp
mlir/lib/Dialect/Bufferization/Transforms/BufferViewFlowAnalysis.cpp
mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
mlir/lib/Dialect/SCF/IR/SCF.cpp
mlir/lib/Dialect/Shape/IR/Shape.cpp
mlir/lib/Dialect/Transform/IR/TransformOps.cpp
mlir/lib/Dialect/Vector/IR/VectorOps.cpp
mlir/lib/Interfaces/ControlFlowInterfaces.cpp
mlir/lib/Transforms/RemoveDeadValues.cpp
mlir/test/lib/Analysis/DataFlow/TestDenseBackwardDataFlowAnalysis.cpp
mlir/test/lib/Dialect/Test/TestDialect.cpp
mlir/unittests/Interfaces/ControlFlowInterfacesTest.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Analysis/DataFlow/DenseAnalysis.h b/mlir/include/mlir/Analysis/DataFlow/DenseAnalysis.h
index 6a1335bab8bf6e..a3a558f7705074 100644
--- a/mlir/include/mlir/Analysis/DataFlow/DenseAnalysis.h
+++ b/mlir/include/mlir/Analysis/DataFlow/DenseAnalysis.h
@@ -353,8 +353,8 @@ class AbstractDenseBackwardDataFlowAnalysis : public DataFlowAnalysis {
/// any effect on the lattice that isn't already expressed by the interface
/// itself.
virtual void visitRegionBranchControlFlowTransfer(
- RegionBranchOpInterface branch, RegionBranchPoint regionFrom,
- RegionBranchPoint regionTo, const AbstractDenseLattice &after,
+ RegionBranchOpInterface branch, std::optional<unsigned> regionFrom,
+ std::optional<unsigned> regionTo, const AbstractDenseLattice &after,
AbstractDenseLattice *before) {
meet(before, after);
}
@@ -382,7 +382,7 @@ class AbstractDenseBackwardDataFlowAnalysis : public DataFlowAnalysis {
/// of the branch operation itself.
void visitRegionBranchOperation(ProgramPoint point,
RegionBranchOpInterface branch,
- RegionBranchPoint branchPoint,
+ std::optional<unsigned> regionNo,
AbstractDenseLattice *before);
/// Visit an operation for which the data flow is described by the
@@ -472,8 +472,9 @@ class DenseBackwardDataFlowAnalysis
/// nullptr`. The behavior can be further refined for specific pairs of "from"
/// and "to" regions.
virtual void visitRegionBranchControlFlowTransfer(
- RegionBranchOpInterface branch, RegionBranchPoint regionFrom,
- RegionBranchPoint regionTo, const LatticeT &after, LatticeT *before) {
+ RegionBranchOpInterface branch, std::optional<unsigned> regionFrom,
+ std::optional<unsigned> regionTo, const LatticeT &after,
+ LatticeT *before) {
AbstractDenseBackwardDataFlowAnalysis::visitRegionBranchControlFlowTransfer(
branch, regionFrom, regionTo, after, before);
}
@@ -507,8 +508,8 @@ class DenseBackwardDataFlowAnalysis
static_cast<LatticeT *>(before));
}
void visitRegionBranchControlFlowTransfer(
- RegionBranchOpInterface branch, RegionBranchPoint regionForm,
- RegionBranchPoint regionTo, const AbstractDenseLattice &after,
+ RegionBranchOpInterface branch, std::optional<unsigned> regionForm,
+ std::optional<unsigned> regionTo, const AbstractDenseLattice &after,
AbstractDenseLattice *before) final {
visitRegionBranchControlFlowTransfer(branch, regionForm, regionTo,
static_cast<const LatticeT &>(after),
diff --git a/mlir/include/mlir/Analysis/DataFlow/SparseAnalysis.h b/mlir/include/mlir/Analysis/DataFlow/SparseAnalysis.h
index 5a9a36159b56c5..13dacff3aa0422 100644
--- a/mlir/include/mlir/Analysis/DataFlow/SparseAnalysis.h
+++ b/mlir/include/mlir/Analysis/DataFlow/SparseAnalysis.h
@@ -243,7 +243,7 @@ class AbstractSparseForwardDataFlowAnalysis : public DataFlowAnalysis {
/// regions or the parent operation itself, and set either the argument or
/// parent result lattices.
void visitRegionSuccessors(ProgramPoint point, RegionBranchOpInterface branch,
- RegionBranchPoint successor,
+ std::optional<unsigned> successorIndex,
ArrayRef<AbstractSparseLattice *> lattices);
};
diff --git a/mlir/include/mlir/Interfaces/ControlFlowInterfaces.h b/mlir/include/mlir/Interfaces/ControlFlowInterfaces.h
index 006aedced839f9..bd81da41ed43cd 100644
--- a/mlir/include/mlir/Interfaces/ControlFlowInterfaces.h
+++ b/mlir/include/mlir/Interfaces/ControlFlowInterfaces.h
@@ -190,68 +190,6 @@ class RegionSuccessor {
ValueRange inputs;
};
-/// This class represents a point being branched from in the methods of the
-/// `RegionBranchOpInterface`.
-/// One can branch from one of two kinds of places:
-/// * The parent operation (aka the `RegionBranchOpInterface` implementation)
-/// * A region within the parent operation.
-class RegionBranchPoint {
-public:
- /// Returns an instance of `RegionBranchPoint` representing the parent
- /// operation.
- static constexpr RegionBranchPoint parent() { return RegionBranchPoint(); }
-
- /// Creates a `RegionBranchPoint` that branches from the given region.
- /// The pointer must not be null.
- RegionBranchPoint(Region *region) : maybeRegion(region) {
- assert(region && "Region must not be null");
- }
-
- RegionBranchPoint(Region &region) : RegionBranchPoint(&region) {}
-
- /// Explicitly stops users from constructing with `nullptr`.
- RegionBranchPoint(std::nullptr_t) = delete;
-
- /// Constructs a `RegionBranchPoint` from the the target of a
- /// `RegionSuccessor` instance.
- RegionBranchPoint(RegionSuccessor successor) {
- if (successor.isParent())
- maybeRegion = nullptr;
- else
- maybeRegion = successor.getSuccessor();
- }
-
- /// Assigns a region being branched from.
- RegionBranchPoint &operator=(Region &region) {
- maybeRegion = &region;
- return *this;
- }
-
- /// Returns true if branching from the parent op.
- bool isParent() const { return maybeRegion == nullptr; }
-
- /// Returns the region if branching from a region.
- /// A null pointer otherwise.
- Region *getRegionOrNull() const { return maybeRegion; }
-
- /// Returns true if the two branch points are equal.
- friend bool operator==(RegionBranchPoint lhs, RegionBranchPoint rhs) {
- return lhs.maybeRegion == rhs.maybeRegion;
- }
-
-private:
- // Private constructor to encourage the use of `RegionBranchPoint::parent`.
- constexpr RegionBranchPoint() : maybeRegion(nullptr) {}
-
- /// Internal encoding. Uses nullptr for representing branching from the parent
- /// op and the region being branched from otherwise.
- Region *maybeRegion;
-};
-
-inline bool operator!=(RegionBranchPoint lhs, RegionBranchPoint rhs) {
- return !(lhs == rhs);
-}
-
/// This class represents upper and lower bounds on the number of times a region
/// of a `RegionBranchOpInterface` can be invoked. The lower bound is at least
/// zero, but the upper bound may not be known.
diff --git a/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td b/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td
index e52636a5ac8fcc..132bd6d53d923a 100644
--- a/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td
+++ b/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td
@@ -133,14 +133,14 @@ def RegionBranchOpInterface : OpInterface<"RegionBranchOpInterface"> {
let methods = [
InterfaceMethod<[{
Returns the operands of this operation used as the entry arguments when
- branching from `point`, which was specified as a successor of
+ entering the region at `index`, which was specified as a successor of
this operation by `getEntrySuccessorRegions`, or the operands forwarded
to the operation's results when it branches back to itself. These operands
should correspond 1-1 with the successor inputs specified in
`getEntrySuccessorRegions`.
}],
"::mlir::OperandRange", "getEntrySuccessorOperands",
- (ins "::mlir::RegionBranchPoint":$point), [{}],
+ (ins "::std::optional<unsigned>":$index), [{}],
/*defaultImplementation=*/[{
auto operandEnd = this->getOperation()->operand_end();
return ::mlir::OperandRange(operandEnd, operandEnd);
@@ -162,20 +162,22 @@ def RegionBranchOpInterface : OpInterface<"RegionBranchOpInterface"> {
(ins "::llvm::ArrayRef<::mlir::Attribute>":$operands,
"::llvm::SmallVectorImpl<::mlir::RegionSuccessor> &":$regions),
[{}], [{
- $_op.getSuccessorRegions(mlir::RegionBranchPoint::parent(), regions);
+ $_op.getSuccessorRegions(std::nullopt, regions);
}]
>,
InterfaceMethod<[{
- Returns the viable successors of `point`. These are the regions that may
- be selected during the flow of control. The parent operation, may
- specify itself as successor, which indicates that the control flow may
- not enter any region at all. This method allows for describing which
- regions may be executed when entering an operation, and which regions
- are executed after having executed another region of the parent op. The
- successor region must be non-empty.
+ Returns the viable successors of a region at `index`, or the possible
+ successors when branching from the parent op if `index` is None. These
+ are the regions that may be selected during the flow of control. The
+ parent operation, i.e. a null `index`, may specify itself as successor,
+ which indicates that the control flow may not enter any region at all.
+ This method allows for describing which regions may be executed when
+ entering an operation, and which regions are executed after having
+ executed another region of the parent op. The successor region must be
+ non-empty.
}],
"void", "getSuccessorRegions",
- (ins "::mlir::RegionBranchPoint":$point,
+ (ins "::std::optional<unsigned>":$index,
"::llvm::SmallVectorImpl<::mlir::RegionSuccessor> &":$regions)
>,
InterfaceMethod<[{
@@ -243,10 +245,12 @@ def RegionBranchTerminatorOpInterface :
let methods = [
InterfaceMethod<[{
Returns a mutable range of operands that are semantically "returned" by
- passing them to the region successor given by `point`.
+ passing them to the region successor given by `index`. If `index` is
+ None, this function returns the operands that are passed as a result to
+ the parent operation.
}],
"::mlir::MutableOperandRange", "getMutableSuccessorOperands",
- (ins "::mlir::RegionBranchPoint":$point)
+ (ins "::std::optional<unsigned>":$index)
>,
InterfaceMethod<[{
Returns the viable region successors that are branched to after this
@@ -265,7 +269,8 @@ def RegionBranchTerminatorOpInterface :
[{
::mlir::Operation *op = $_op;
::llvm::cast<::mlir::RegionBranchOpInterface>(op->getParentOp())
- .getSuccessorRegions(op->getParentRegion(), regions);
+ .getSuccessorRegions(op->getParentRegion()->getRegionNumber(),
+ regions);
}]
>,
];
@@ -285,8 +290,8 @@ def RegionBranchTerminatorOpInterface :
// them to the region successor given by `index`. If `index` is None, this
// function returns the operands that are passed as a result to the parent
// operation.
- ::mlir::OperandRange getSuccessorOperands(::mlir::RegionBranchPoint point) {
- return getMutableSuccessorOperands(point);
+ ::mlir::OperandRange getSuccessorOperands(std::optional<unsigned> index) {
+ return getMutableSuccessorOperands(index);
}
}];
}
@@ -304,7 +309,7 @@ def ReturnLike : TraitList<[
/*extraOpDeclaration=*/"",
/*extraOpDefinition=*/[{
::mlir::MutableOperandRange $cppClass::getMutableSuccessorOperands(
- ::mlir::RegionBranchPoint point) {
+ ::std::optional<unsigned> index) {
return ::mlir::MutableOperandRange(*this);
}
}]
diff --git a/mlir/lib/Analysis/AliasAnalysis/LocalAliasAnalysis.cpp b/mlir/lib/Analysis/AliasAnalysis/LocalAliasAnalysis.cpp
index ae2ba90412137c..970e68bc258649 100644
--- a/mlir/lib/Analysis/AliasAnalysis/LocalAliasAnalysis.cpp
+++ b/mlir/lib/Analysis/AliasAnalysis/LocalAliasAnalysis.cpp
@@ -45,9 +45,9 @@ static void collectUnderlyingAddressValues(RegionBranchOpInterface branch,
// this region predecessor that correspond to the input values of `region`. If
// an index could not be found, std::nullopt is returned instead.
auto getOperandIndexIfPred =
- [&](RegionBranchPoint pred) -> std::optional<unsigned> {
+ [&](std::optional<unsigned> predIndex) -> std::optional<unsigned> {
SmallVector<RegionSuccessor, 2> successors;
- branch.getSuccessorRegions(pred, successors);
+ branch.getSuccessorRegions(predIndex, successors);
for (RegionSuccessor &successor : successors) {
if (successor.getSuccessor() != region)
continue;
@@ -75,27 +75,28 @@ static void collectUnderlyingAddressValues(RegionBranchOpInterface branch,
};
// Check branches from the parent operation.
- auto branchPoint = RegionBranchPoint::parent();
- if (region)
- branchPoint = region;
-
+ std::optional<unsigned> regionIndex;
+ if (region) {
+ // Determine the actual region number from the passed region.
+ regionIndex = region->getRegionNumber();
+ }
if (std::optional<unsigned> operandIndex =
- getOperandIndexIfPred(/*predIndex=*/RegionBranchPoint::parent())) {
+ getOperandIndexIfPred(/*predIndex=*/std::nullopt)) {
collectUnderlyingAddressValues(
- branch.getEntrySuccessorOperands(branchPoint)[*operandIndex], maxDepth,
+ branch.getEntrySuccessorOperands(regionIndex)[*operandIndex], maxDepth,
visited, output);
}
// Check branches from each child region.
Operation *op = branch.getOperation();
- for (Region &region : op->getRegions()) {
- if (std::optional<unsigned> operandIndex = getOperandIndexIfPred(region)) {
- for (Block &block : region) {
+ for (int i = 0, e = op->getNumRegions(); i != e; ++i) {
+ if (std::optional<unsigned> operandIndex = getOperandIndexIfPred(i)) {
+ for (Block &block : op->getRegion(i)) {
// Try to determine possible region-branch successor operands for the
// current region.
if (auto term = dyn_cast<RegionBranchTerminatorOpInterface>(
block.getTerminator())) {
collectUnderlyingAddressValues(
- term.getSuccessorOperands(branchPoint)[*operandIndex], maxDepth,
+ term.getSuccessorOperands(regionIndex)[*operandIndex], maxDepth,
visited, output);
} else if (block.getNumSuccessors()) {
// Otherwise, if this terminator may exit the region we can't make
diff --git a/mlir/lib/Analysis/DataFlow/DenseAnalysis.cpp b/mlir/lib/Analysis/DataFlow/DenseAnalysis.cpp
index eab408cd5977c3..c79a360e4c11bb 100644
--- a/mlir/lib/Analysis/DataFlow/DenseAnalysis.cpp
+++ b/mlir/lib/Analysis/DataFlow/DenseAnalysis.cpp
@@ -312,8 +312,7 @@ void AbstractDenseBackwardDataFlowAnalysis::processOperation(Operation *op) {
// Special cases where control flow may dictate data flow.
if (auto branch = dyn_cast<RegionBranchOpInterface>(op))
- return visitRegionBranchOperation(op, branch, RegionBranchPoint::parent(),
- before);
+ return visitRegionBranchOperation(op, branch, std::nullopt, before);
if (auto call = dyn_cast<CallOpInterface>(op))
return visitCallOperation(call, before);
@@ -369,7 +368,8 @@ void AbstractDenseBackwardDataFlowAnalysis::visitBlock(Block *block) {
// If this block is exiting from an operation with region-based control
// flow, propagate the lattice back along the control flow edge.
if (auto branch = dyn_cast<RegionBranchOpInterface>(block->getParentOp())) {
- visitRegionBranchOperation(block, branch, block->getParent(), before);
+ visitRegionBranchOperation(block, branch,
+ block->getParent()->getRegionNumber(), before);
return;
}
@@ -396,13 +396,13 @@ void AbstractDenseBackwardDataFlowAnalysis::visitBlock(Block *block) {
void AbstractDenseBackwardDataFlowAnalysis::visitRegionBranchOperation(
ProgramPoint point, RegionBranchOpInterface branch,
- RegionBranchPoint branchPoint, AbstractDenseLattice *before) {
+ std::optional<unsigned> regionNo, AbstractDenseLattice *before) {
// The successors of the operation may be either the first operation of the
// entry block of each possible successor region, or the next operation when
// the branch is a successor of itself.
SmallVector<RegionSuccessor> successors;
- branch.getSuccessorRegions(branchPoint, successors);
+ branch.getSuccessorRegions(regionNo, successors);
for (const RegionSuccessor &successor : successors) {
const AbstractDenseLattice *after;
if (successor.isParent() || successor.getSuccessor()->empty()) {
@@ -423,8 +423,10 @@ void AbstractDenseBackwardDataFlowAnalysis::visitRegionBranchOperation(
else
after = getLatticeFor(point, &successorBlock->front());
}
-
- visitRegionBranchControlFlowTransfer(branch, branchPoint, successor, *after,
+ std::optional<unsigned> successorNo =
+ successor.isParent() ? std::optional<unsigned>()
+ : successor.getSuccessor()->getRegionNumber();
+ visitRegionBranchControlFlowTransfer(branch, regionNo, successorNo, *after,
before);
}
}
diff --git a/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp b/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp
index 02a0ce1bb29213..4708cdb042f126 100644
--- a/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp
+++ b/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp
@@ -99,7 +99,7 @@ void AbstractSparseForwardDataFlowAnalysis::visitOperation(Operation *op) {
// The results of a region branch operation are determined by control-flow.
if (auto branch = dyn_cast<RegionBranchOpInterface>(op)) {
return visitRegionSuccessors({branch}, branch,
- /*successor=*/RegionBranchPoint::parent(),
+ /*successorIndex=*/std::nullopt,
resultLattices);
}
@@ -167,8 +167,8 @@ void AbstractSparseForwardDataFlowAnalysis::visitBlock(Block *block) {
// Check if the lattices can be determined from region control flow.
if (auto branch = dyn_cast<RegionBranchOpInterface>(block->getParentOp())) {
- return visitRegionSuccessors(block, branch, block->getParent(),
- argLattices);
+ return visitRegionSuccessors(
+ block, branch, block->getParent()->getRegionNumber(), argLattices);
}
// Otherwise, we can't reason about the data-flow.
@@ -212,7 +212,8 @@ void AbstractSparseForwardDataFlowAnalysis::visitBlock(Block *block) {
void AbstractSparseForwardDataFlowAnalysis::visitRegionSuccessors(
ProgramPoint point, RegionBranchOpInterface branch,
- RegionBranchPoint successor, ArrayRef<AbstractSparseLattice *> lattices) {
+ std::optional<unsigned> successorIndex,
+ ArrayRef<AbstractSparseLattice *> lattices) {
const auto *predecessors = getOrCreateFor<PredecessorState>(point, point);
assert(predecessors->allPredecessorsKnown() &&
"unexpected unresolved region successors");
@@ -223,11 +224,11 @@ void AbstractSparseForwardDataFlowAnalysis::visitRegionSuccessors(
// Check if the predecessor is the parent op.
if (op == branch) {
- operands = branch.getEntrySuccessorOperands(successor);
+ operands = branch.getEntrySuccessorOperands(successorIndex);
// Otherwise, try to deduce the operands from a region return-like op.
} else if (auto regionTerminator =
dyn_cast<RegionBranchTerminatorOpInterface>(op)) {
- operands = regionTerminator.getSuccessorOperands(successor);
+ operands = regionTerminator.getSuccessorOperands(successorIndex);
}
if (!operands) {
@@ -500,7 +501,10 @@ void AbstractSparseBackwardDataFlowAnalysis::visitRegionSuccessors(
BitVector unaccounted(op->getNumOperands(), true);
for (RegionSuccessor &successor : successors) {
- OperandRange operands = branch.getEntrySuccessorOperands(successor);
+ Region *region = successor.getSuccessor();
+ OperandRange operands =
+ region ? branch.getEntrySuccessorOperands(region->getRegionNumber())
+ : branch.getEntrySuccessorOperands({});
MutableArrayRef<OpOperand> opoperands = operandsToOpOperands(operands);
ValueRange inputs = successor.getSuccessorInputs();
for (auto [operand, input] : llvm::zip(opoperands, inputs)) {
@@ -534,7 +538,9 @@ void AbstractSparseBackwardDataFlowAnalysis::
for (const RegionSuccessor &successor : successors) {
ValueRange inputs = successor.getSuccessorInputs();
- OperandRange operands = terminator.getSuccessorOperands(successor);
+ Region *region = successor.getSuccessor();
+ OperandRange operands = terminator.getSuccessorOperands(
+ region ? region->getRegionNumber() : std::optional<unsigned>{});
MutableArrayRef<OpOperand> opOperands = operandsToOpOperands(operands);
for (auto [opOperand, input] : llvm::zip(opOperands, inputs)) {
meet(getLatticeElement(opOperand.get()),
diff --git a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
index 9d7b8f371a26c6..bb4aaee21d019e 100644
--- a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
+++ b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
@@ -2379,9 +2379,9 @@ void AffineForOp::getCanonicalizationPatterns(RewritePatternSet &results,
/// correspond to the loop iterator operands, i.e., those excluding the
/// induction variable. AffineForOp only has one region, so zero is the only
/// valid value for `index`.
-OperandRange AffineForOp::getEntrySuccessorOperands(RegionBranchPoint point) {
- assert((point.isParent() || point == getLoopBody()) &&
- "invalid region point");
+OperandRange
+AffineForOp::getEntrySuccessorOperands(std::optional<unsigned> index) {
+ assert((!index || *index == 0) && "invalid region index");
// The initial operands map to the loop arguments after the induction
// variable or are forwarded to the results when the trip count is zero.
@@ -2394,15 +2394,14 @@ OperandRange AffineForOp::getEntrySuccessorOperands(RegionBranchPoint point) {
/// correspond to a constant value for each operand, or null if that operand is
/// not a constant.
void AffineForOp::getSuccessorRegions(
- RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
- assert((point.isParent() || point == getLoopBody()) &&
- "expected loop region");
+ std::optional<unsigned> index, SmallVectorImpl<RegionSuccessor> &regions) {
+ assert((!index.has_value() || index.value() == 0) && "expected loop region");
// The loop may typically branch back to its body or to the parent operation.
// If the predecessor is the parent op and the trip count is known to be at
// least one, branch into the body using the iterator arguments. And in cases
// we know the trip count is zero, it can only branch back to its parent.
std::optional<uint64_t> tripCount = getTrivialConstantTripCount(*this);
- if (point.isParent() && tripCount.has_value()) {
+ if (!index.has_value() && tripCount.has_value()) {
if (tripCount.value() > 0) {
regions.push_back(RegionSuccessor(&getLoopBody(), getRegionIterArgs()));
return;
@@ -2415,7 +2414,7 @@ void AffineForOp::getSuccessorRegions(
// From the loop body, if the trip count is one, we can only branch back to
// the parent.
- if (!point.isParent() && tripCount && *tripCount == 1) {
+ if (index && tripCount && *tripCount == 1) {
regions.push_back(RegionSuccessor(getResults()));
return;
}
@@ -2860,10 +2859,10 @@ struct AlwaysTrueOrFalseIf : public OpRewritePattern<AffineIfOp> {
/// AffineIfOp has two regions -- `then` and `else`. The flow of data should be
/// as follows: AffineIfOp -> `then`/`else` -> AffineIfOp
void AffineIfOp::getSuccessorRegions(
- RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
+ std::optional<unsigned> index, SmallVectorImpl<RegionSuccessor> &regions) {
// If the predecessor is an AffineIfOp, then branching into both `then` and
// `else` region is valid.
- if (point.isParent()) {
+ if (!index.has_value()) {
regions.reserve(2);
regions.push_back(
RegionSuccessor(&getThenRegion(), getThenRegion().getArguments()));
diff --git a/mlir/lib/Dialect/Async/IR/Async.cpp b/mlir/lib/Dialect/Async/IR/Async.cpp
index a05e02faf6d2f0..9b4fb81990c169 100644
--- a/mlir/lib/Dialect/Async/IR/Async.cpp
+++ b/mlir/lib/Dialect/Async/IR/Async.cpp
@@ -38,8 +38,9 @@ void AsyncDialect::initialize() {
constexpr char kOperandSegmentSizesAttr[] = "operandSegmentSizes";
-OperandRange ExecuteOp::getEntrySuccessorOperands(RegionBranchPoint point) {
- assert(point == getBodyRegion() && "invalid region index");
+OperandRange
+ExecuteOp::getEntrySuccessorOperands(std::optional<unsigned> index) {
+ assert(index && *index == 0 && "invalid region index");
return getBodyOperands();
}
@@ -52,10 +53,11 @@ bool ExecuteOp::areTypesCompatible(Type lhs, Type rhs) {
return getValueOrTokenType(lhs) == getValueOrTokenType(rhs);
}
-void ExecuteOp::getSuccessorRegions(RegionBranchPoint point,
+void ExecuteOp::getSuccessorRegions(std::optional<unsigned> index,
SmallVectorImpl<RegionSuccessor> &regions) {
// The `body` region branch back to the parent operation.
- if (point == getBodyRegion()) {
+ if (index) {
+ assert(*index == 0 && "invalid region index");
regions.push_back(RegionSuccessor(getBodyResults()));
return;
}
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocation.cpp b/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocation.cpp
index 9a831e4c322ea0..582974873263d2 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocation.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocation.cpp
@@ -372,7 +372,7 @@ class BufferDeallocation : public BufferPlacementTransformationBase {
// parent operation. In this case, we have to introduce an additional clone
// for buffer that is passed to the argument.
SmallVector<RegionSuccessor, 2> successorRegions;
- regionInterface.getSuccessorRegions(/*point=*/RegionBranchPoint::parent(),
+ regionInterface.getSuccessorRegions(/*index=*/std::nullopt,
successorRegions);
auto *it =
llvm::find_if(successorRegions, [&](RegionSuccessor &successorRegion) {
@@ -383,7 +383,8 @@ class BufferDeallocation : public BufferPlacementTransformationBase {
// Determine the actual operand to introduce a clone for and rewire the
// operand to point to the clone instead.
- auto operands = regionInterface.getEntrySuccessorOperands(argRegion);
+ auto operands =
+ regionInterface.getEntrySuccessorOperands(argRegion->getRegionNumber());
size_t operandIndex =
llvm::find(it->getSuccessorInputs(), blockArg).getIndex() +
operands.getBeginOperandIndex();
@@ -431,7 +432,8 @@ class BufferDeallocation : public BufferPlacementTransformationBase {
// Query the regionInterface to get all successor regions of the current
// one.
SmallVector<RegionSuccessor, 2> successorRegions;
- regionInterface.getSuccessorRegions(region, successorRegions);
+ regionInterface.getSuccessorRegions(region.getRegionNumber(),
+ successorRegions);
// Try to find a matching region successor.
RegionSuccessor *regionSuccessor =
llvm::find_if(successorRegions, regionPredicate);
@@ -443,6 +445,10 @@ class BufferDeallocation : public BufferPlacementTransformationBase {
llvm::find(regionSuccessor->getSuccessorInputs(), argValue)
.getIndex();
+ std::optional<unsigned> successorRegionNumber;
+ if (Region *successorRegion = regionSuccessor->getSuccessor())
+ successorRegionNumber = successorRegion->getRegionNumber();
+
// Iterate over all immediate terminator operations to introduce
// new buffer allocations. Thereby, the appropriate terminator operand
// will be adjusted to point to the newly allocated buffer instead.
@@ -450,7 +456,8 @@ class BufferDeallocation : public BufferPlacementTransformationBase {
&region, [&](RegionBranchTerminatorOpInterface terminator) {
// Get the actual mutable operands for this terminator op.
auto terminatorOperands =
- terminator.getMutableSuccessorOperands(*regionSuccessor);
+ terminator.getMutableSuccessorOperands(
+ successorRegionNumber);
// Extract the source value from the current terminator.
// This conversion needs to exist on a separate line due to a
// bug in GCC conversion analysis.
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/BufferUtils.cpp b/mlir/lib/Dialect/Bufferization/Transforms/BufferUtils.cpp
index 119801f9cc92f3..f8231cac778af6 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/BufferUtils.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/BufferUtils.cpp
@@ -123,7 +123,7 @@ bool BufferPlacementTransformationBase::isLoop(Operation *op) {
return true;
// Recurses into all region successors.
SmallVector<RegionSuccessor, 2> successors;
- regionInterface.getSuccessorRegions(current, successors);
+ regionInterface.getSuccessorRegions(current->getRegionNumber(), successors);
for (RegionSuccessor &regionEntry : successors)
if (recurse(regionEntry.getSuccessor()))
return true;
@@ -132,8 +132,7 @@ bool BufferPlacementTransformationBase::isLoop(Operation *op) {
// Start with all entry regions and test whether they induce a loop.
SmallVector<RegionSuccessor, 2> successorRegions;
- regionInterface.getSuccessorRegions(/*point=*/RegionBranchPoint::parent(),
- successorRegions);
+ regionInterface.getSuccessorRegions(/*index=*/std::nullopt, successorRegions);
for (RegionSuccessor &regionEntry : successorRegions) {
if (recurse(regionEntry.getSuccessor()))
return true;
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/BufferViewFlowAnalysis.cpp b/mlir/lib/Dialect/Bufferization/Transforms/BufferViewFlowAnalysis.cpp
index 98a60a48763ab1..d201e024380661 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/BufferViewFlowAnalysis.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/BufferViewFlowAnalysis.cpp
@@ -100,13 +100,16 @@ void BufferViewFlowAnalysis::build(Operation *op) {
// Query the RegionBranchOpInterface to find potential successor regions.
// Extract all entry regions and wire all initial entry successor inputs.
SmallVector<RegionSuccessor, 2> entrySuccessors;
- regionInterface.getSuccessorRegions(/*point=*/RegionBranchPoint::parent(),
+ regionInterface.getSuccessorRegions(/*index=*/std::nullopt,
entrySuccessors);
for (RegionSuccessor &entrySuccessor : entrySuccessors) {
// Wire the entry region's successor arguments with the initial
// successor inputs.
registerDependencies(
- regionInterface.getEntrySuccessorOperands(entrySuccessor),
+ regionInterface.getEntrySuccessorOperands(
+ entrySuccessor.isParent()
+ ? std::optional<unsigned>()
+ : entrySuccessor.getSuccessor()->getRegionNumber()),
entrySuccessor.getSuccessorInputs());
}
@@ -115,16 +118,21 @@ void BufferViewFlowAnalysis::build(Operation *op) {
// Iterate over all successor region entries that are reachable from the
// current region.
SmallVector<RegionSuccessor, 2> successorRegions;
- regionInterface.getSuccessorRegions(region, successorRegions);
+ regionInterface.getSuccessorRegions(region.getRegionNumber(),
+ successorRegions);
for (RegionSuccessor &successorRegion : successorRegions) {
+ // Determine the current region index (if any).
+ std::optional<unsigned> regionIndex;
+ Region *regionSuccessor = successorRegion.getSuccessor();
+ if (regionSuccessor)
+ regionIndex = regionSuccessor->getRegionNumber();
// Iterate over all immediate terminator operations and wire the
// successor inputs with the successor operands of each terminator.
for (Block &block : region)
if (auto terminator = dyn_cast<RegionBranchTerminatorOpInterface>(
block.getTerminator()))
- registerDependencies(
- terminator.getSuccessorOperands(successorRegion),
- successorRegion.getSuccessorInputs());
+ registerDependencies(terminator.getSuccessorOperands(regionIndex),
+ successorRegion.getSuccessorInputs());
}
}
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index 9c5c322e23692b..e1b8dd62450a77 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -455,8 +455,8 @@ ParseResult AllocaScopeOp::parse(OpAsmParser &parser, OperationState &result) {
}
void AllocaScopeOp::getSuccessorRegions(
- RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
- if (!point.isParent()) {
+ std::optional<unsigned> index, SmallVectorImpl<RegionSuccessor> &regions) {
+ if (index) {
regions.push_back(RegionSuccessor(getResults()));
return;
}
diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
index b573291f0460e6..63ce3b2a469627 100644
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -266,9 +266,9 @@ void ExecuteRegionOp::getCanonicalizationPatterns(RewritePatternSet &results,
/// correspond to a constant value for each operand, or null if that operand is
/// not a constant.
void ExecuteRegionOp::getSuccessorRegions(
- RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
+ std::optional<unsigned> index, SmallVectorImpl<RegionSuccessor> &regions) {
// If the predecessor is the ExecuteRegionOp, branch into the body.
- if (point.isParent()) {
+ if (!index) {
regions.push_back(RegionSuccessor(&getRegion()));
return;
}
@@ -282,8 +282,8 @@ void ExecuteRegionOp::getSuccessorRegions(
//===----------------------------------------------------------------------===//
MutableOperandRange
-ConditionOp::getMutableSuccessorOperands(RegionBranchPoint point) {
- assert((point.isParent() || point == getParentOp().getAfter()) &&
+ConditionOp::getMutableSuccessorOperands(std::optional<unsigned> index) {
+ assert((!index || index == getParentOp().getAfter().getRegionNumber()) &&
"condition op can only exit the loop or branch to the after"
"region");
// Pass all operands except the condition to the successor region.
@@ -553,7 +553,7 @@ ForOp mlir::scf::getForInductionVarOwner(Value val) {
/// Return operands used when entering the region at 'index'. These operands
/// correspond to the loop iterator operands, i.e., those excluding the
/// induction variable.
-OperandRange ForOp::getEntrySuccessorOperands(RegionBranchPoint point) {
+OperandRange ForOp::getEntrySuccessorOperands(std::optional<unsigned> index) {
return getInitArgs();
}
@@ -562,7 +562,7 @@ OperandRange ForOp::getEntrySuccessorOperands(RegionBranchPoint point) {
/// during the flow of control. `operands` is a set of optional attributes that
/// correspond to a constant value for each operand, or null if that operand is
/// not a constant.
-void ForOp::getSuccessorRegions(RegionBranchPoint point,
+void ForOp::getSuccessorRegions(std::optional<unsigned> index,
SmallVectorImpl<RegionSuccessor> &regions) {
// Both the operation itself and the region may be branching into the body or
// back into the operation itself. It is possible for loop not to enter the
@@ -1731,7 +1731,7 @@ void ForallOp::getCanonicalizationPatterns(RewritePatternSet &results,
/// during the flow of control. `operands` is a set of optional attributes that
/// correspond to a constant value for each operand, or null if that operand is
/// not a constant.
-void ForallOp::getSuccessorRegions(RegionBranchPoint point,
+void ForallOp::getSuccessorRegions(std::optional<unsigned> index,
SmallVectorImpl<RegionSuccessor> &regions) {
// Both the operation itself and the region may be branching into the body or
// back into the operation itself. It is possible for loop not to enter the
@@ -2011,10 +2011,10 @@ void IfOp::print(OpAsmPrinter &p) {
/// during the flow of control. `operands` is a set of optional attributes that
/// correspond to a constant value for each operand, or null if that operand is
/// not a constant.
-void IfOp::getSuccessorRegions(RegionBranchPoint point,
+void IfOp::getSuccessorRegions(std::optional<unsigned> index,
SmallVectorImpl<RegionSuccessor> &regions) {
// The `then` and the `else` region branch back to the parent operation.
- if (!point.isParent()) {
+ if (index) {
regions.push_back(RegionSuccessor(getResults()));
return;
}
@@ -3042,7 +3042,7 @@ void ParallelOp::getCanonicalizationPatterns(RewritePatternSet &results,
/// correspond to a constant value for each operand, or null if that operand is
/// not a constant.
void ParallelOp::getSuccessorRegions(
- RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
+ std::optional<unsigned> index, SmallVectorImpl<RegionSuccessor> &regions) {
// Both the operation itself and the region may be branching into the body or
// back into the operation itself. It is possible for loop not to enter the
// body.
@@ -3169,8 +3169,8 @@ void WhileOp::build(::mlir::OpBuilder &odsBuilder,
afterBuilder(odsBuilder, odsState.location, afterBlock->getArguments());
}
-OperandRange WhileOp::getEntrySuccessorOperands(RegionBranchPoint point) {
- assert(point == getBefore() &&
+OperandRange WhileOp::getEntrySuccessorOperands(std::optional<unsigned> index) {
+ assert(index && *index == 0 &&
"WhileOp is expected to branch only to the first region");
return getInits();
@@ -3192,18 +3192,17 @@ Block::BlockArgListType WhileOp::getAfterArguments() {
return getAfterBody()->getArguments();
}
-void WhileOp::getSuccessorRegions(RegionBranchPoint point,
+void WhileOp::getSuccessorRegions(std::optional<unsigned> index,
SmallVectorImpl<RegionSuccessor> &regions) {
// The parent op always branches to the condition region.
- if (point.isParent()) {
+ if (!index) {
regions.emplace_back(&getBefore(), getBefore().getArguments());
return;
}
- assert(llvm::is_contained({&getAfter(), &getBefore()}, point) &&
- "there are only two regions in a WhileOp");
+ assert(*index < 2 && "there are only two regions in a WhileOp");
// The body region always branches back to the condition region.
- if (point == getAfter()) {
+ if (*index == 1) {
regions.emplace_back(&getBefore(), getBefore().getArguments());
return;
}
@@ -4024,9 +4023,10 @@ Block &scf::IndexSwitchOp::getCaseBlock(unsigned idx) {
}
void IndexSwitchOp::getSuccessorRegions(
- RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &successors) {
+ std::optional<unsigned> index,
+ SmallVectorImpl<RegionSuccessor> &successors) {
// All regions branch back to the parent op.
- if (!point.isParent()) {
+ if (index) {
successors.emplace_back(getResults());
return;
}
diff --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp
index 78b06d9ce033f8..c52a9c5004aaaf 100644
--- a/mlir/lib/Dialect/Shape/IR/Shape.cpp
+++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp
@@ -335,11 +335,11 @@ void AssumingOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
// See RegionBranchOpInterface in Interfaces/ControlFlowInterfaces.td
void AssumingOp::getSuccessorRegions(
- RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
+ std::optional<unsigned> index, SmallVectorImpl<RegionSuccessor> &regions) {
// AssumingOp has unconditional control flow into the region and back to the
// parent, so return the correct RegionSuccessor purely based on the index
// being None or 0.
- if (!point.isParent()) {
+ if (index) {
regions.push_back(RegionSuccessor(getResults()));
return;
}
diff --git a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
index 518bfc3931e8d3..7bc7272b054129 100644
--- a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
@@ -86,25 +86,23 @@ ensurePayloadIsSeparateFromTransform(transform::TransformOpInterface transform,
// AlternativesOp
//===----------------------------------------------------------------------===//
-OperandRange
-transform::AlternativesOp::getEntrySuccessorOperands(RegionBranchPoint point) {
- if (!point.isParent() && getOperation()->getNumOperands() == 1)
+OperandRange transform::AlternativesOp::getEntrySuccessorOperands(
+ std::optional<unsigned> index) {
+ if (index && getOperation()->getNumOperands() == 1)
return getOperation()->getOperands();
return OperandRange(getOperation()->operand_end(),
getOperation()->operand_end());
}
void transform::AlternativesOp::getSuccessorRegions(
- RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
+ std::optional<unsigned> index, SmallVectorImpl<RegionSuccessor> &regions) {
for (Region &alternative : llvm::drop_begin(
- getAlternatives(),
- point.isParent() ? 0
- : point.getRegionOrNull()->getRegionNumber() + 1)) {
+ getAlternatives(), index.has_value() ? *index + 1 : 0)) {
regions.emplace_back(&alternative, !getOperands().empty()
? alternative.getArguments()
: Block::BlockArgListType());
}
- if (!point.isParent())
+ if (index.has_value())
regions.emplace_back(getOperation()->getResults());
}
@@ -1161,24 +1159,24 @@ void transform::ForeachOp::getEffects(
}
void transform::ForeachOp::getSuccessorRegions(
- RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
+ std::optional<unsigned> index, SmallVectorImpl<RegionSuccessor> &regions) {
Region *bodyRegion = &getBody();
- if (point.isParent()) {
+ if (!index) {
regions.emplace_back(bodyRegion, bodyRegion->getArguments());
return;
}
// Branch back to the region or the parent.
- assert(point == getBody() && "unexpected region index");
+ assert(*index == 0 && "unexpected region index");
regions.emplace_back(bodyRegion, bodyRegion->getArguments());
regions.emplace_back();
}
OperandRange
-transform::ForeachOp::getEntrySuccessorOperands(RegionBranchPoint point) {
+transform::ForeachOp::getEntrySuccessorOperands(std::optional<unsigned> index) {
// The iteration variable op handle is mapped to a subset (one op to be
// precise) of the payload ops of the ForeachOp operand.
- assert(point == getBody() && "unexpected region index");
+ assert(index && *index == 0 && "unexpected region index");
return getOperation()->getOperands();
}
@@ -2180,9 +2178,9 @@ void transform::SequenceOp::getEffects(
getPotentialTopLevelEffects(effects);
}
-OperandRange
-transform::SequenceOp::getEntrySuccessorOperands(RegionBranchPoint point) {
- assert(point == getBody() && "unexpected region index");
+OperandRange transform::SequenceOp::getEntrySuccessorOperands(
+ std::optional<unsigned> index) {
+ assert(index && *index == 0 && "unexpected region index");
if (getOperation()->getNumOperands() > 0)
return getOperation()->getOperands();
return OperandRange(getOperation()->operand_end(),
@@ -2190,8 +2188,8 @@ transform::SequenceOp::getEntrySuccessorOperands(RegionBranchPoint point) {
}
void transform::SequenceOp::getSuccessorRegions(
- RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
- if (point.isParent()) {
+ std::optional<unsigned> index, SmallVectorImpl<RegionSuccessor> &regions) {
+ if (!index) {
Region *bodyRegion = &getBody();
regions.emplace_back(bodyRegion, getNumOperands() != 0
? bodyRegion->getArguments()
@@ -2199,7 +2197,7 @@ void transform::SequenceOp::getSuccessorRegions(
return;
}
- assert(point == getBody() && "unexpected region index");
+ assert(*index == 0 && "unexpected region index");
regions.emplace_back(getOperation()->getResults());
}
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 88bda3931a5a11..4e9364611b257d 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -5821,8 +5821,8 @@ ParseResult WarpExecuteOnLane0Op::parse(OpAsmParser &parser,
}
void WarpExecuteOnLane0Op::getSuccessorRegions(
- RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
- if (!point.isParent()) {
+ std::optional<unsigned> index, SmallVectorImpl<RegionSuccessor> &regions) {
+ if (index) {
regions.push_back(RegionSuccessor(getResults()));
return;
}
diff --git a/mlir/lib/Interfaces/ControlFlowInterfaces.cpp b/mlir/lib/Interfaces/ControlFlowInterfaces.cpp
index b3166155e84f93..cc90da370de693 100644
--- a/mlir/lib/Interfaces/ControlFlowInterfaces.cpp
+++ b/mlir/lib/Interfaces/ControlFlowInterfaces.cpp
@@ -84,18 +84,18 @@ detail::verifyBranchSuccessorOperands(Operation *op, unsigned succNo,
// RegionBranchOpInterface
//===----------------------------------------------------------------------===//
-static InFlightDiagnostic &printRegionEdgeName(InFlightDiagnostic &diag,
- RegionBranchPoint sourceNo,
- RegionBranchPoint succRegionNo) {
+static InFlightDiagnostic &
+printRegionEdgeName(InFlightDiagnostic &diag, std::optional<unsigned> sourceNo,
+ std::optional<unsigned> succRegionNo) {
diag << "from ";
- if (Region *region = sourceNo.getRegionOrNull())
- diag << "Region #" << region->getRegionNumber();
+ if (sourceNo)
+ diag << "Region #" << sourceNo.value();
else
diag << "parent operands";
diag << " to ";
- if (Region *region = succRegionNo.getRegionOrNull())
- diag << "Region #" << region->getRegionNumber();
+ if (succRegionNo)
+ diag << "Region #" << succRegionNo.value();
else
diag << "parent results";
return diag;
@@ -107,24 +107,28 @@ static InFlightDiagnostic &printRegionEdgeName(InFlightDiagnostic &diag,
/// inputs that flow from `sourceIndex' to the given region, or std::nullopt if
/// the exact type match verification is not necessary (e.g., if the Op verifies
/// the match itself).
-static LogicalResult
-verifyTypesAlongAllEdges(Operation *op, RegionBranchPoint sourcePoint,
- function_ref<FailureOr<TypeRange>(RegionBranchPoint)>
- getInputsTypesForRegion) {
+static LogicalResult verifyTypesAlongAllEdges(
+ Operation *op, std::optional<unsigned> sourceNo,
+ function_ref<FailureOr<TypeRange>(std::optional<unsigned>)>
+ getInputsTypesForRegion) {
auto regionInterface = cast<RegionBranchOpInterface>(op);
SmallVector<RegionSuccessor, 2> successors;
- regionInterface.getSuccessorRegions(sourcePoint, successors);
+ regionInterface.getSuccessorRegions(sourceNo, successors);
for (RegionSuccessor &succ : successors) {
- FailureOr<TypeRange> sourceTypes = getInputsTypesForRegion(succ);
+ std::optional<unsigned> succRegionNo;
+ if (!succ.isParent())
+ succRegionNo = succ.getSuccessor()->getRegionNumber();
+
+ FailureOr<TypeRange> sourceTypes = getInputsTypesForRegion(succRegionNo);
if (failed(sourceTypes))
return failure();
TypeRange succInputsTypes = succ.getSuccessorInputs().getTypes();
if (sourceTypes->size() != succInputsTypes.size()) {
InFlightDiagnostic diag = op->emitOpError(" region control flow edge ");
- return printRegionEdgeName(diag, sourcePoint, succ)
+ return printRegionEdgeName(diag, sourceNo, succRegionNo)
<< ": source has " << sourceTypes->size()
<< " operands, but target successor needs "
<< succInputsTypes.size();
@@ -136,7 +140,7 @@ verifyTypesAlongAllEdges(Operation *op, RegionBranchPoint sourcePoint,
Type inputType = std::get<1>(typesIdx.value());
if (!regionInterface.areTypesCompatible(sourceType, inputType)) {
InFlightDiagnostic diag = op->emitOpError(" along control flow edge ");
- return printRegionEdgeName(diag, sourcePoint, succ)
+ return printRegionEdgeName(diag, sourceNo, succRegionNo)
<< ": source type #" << typesIdx.index() << " " << sourceType
<< " should match input type #" << typesIdx.index() << " "
<< inputType;
@@ -150,13 +154,13 @@ verifyTypesAlongAllEdges(Operation *op, RegionBranchPoint sourcePoint,
LogicalResult detail::verifyTypesAlongControlFlowEdges(Operation *op) {
auto regionInterface = cast<RegionBranchOpInterface>(op);
- auto inputTypesFromParent = [&](RegionBranchPoint regionNo) -> TypeRange {
+ auto inputTypesFromParent =
+ [&](std::optional<unsigned> regionNo) -> TypeRange {
return regionInterface.getEntrySuccessorOperands(regionNo).getTypes();
};
// Verify types along control flow edges originating from the parent.
- if (failed(verifyTypesAlongAllEdges(op, RegionBranchPoint::parent(),
- inputTypesFromParent)))
+ if (failed(verifyTypesAlongAllEdges(op, std::nullopt, inputTypesFromParent)))
return failure();
auto areTypesCompatible = [&](TypeRange lhs, TypeRange rhs) {
@@ -172,7 +176,8 @@ LogicalResult detail::verifyTypesAlongControlFlowEdges(Operation *op) {
};
// Verify types along control flow edges originating from each region.
- for (Region &region : op->getRegions()) {
+ for (unsigned regionNo : llvm::seq(0U, op->getNumRegions())) {
+ Region &region = op->getRegion(regionNo);
// Since there can be multiple terminators implementing the
// `RegionBranchTerminatorOpInterface`, all should have the same operand
@@ -190,7 +195,7 @@ LogicalResult detail::verifyTypesAlongControlFlowEdges(Operation *op) {
continue;
auto inputTypesForRegion =
- [&](RegionBranchPoint succRegionNo) -> FailureOr<TypeRange> {
+ [&](std::optional<unsigned> succRegionNo) -> FailureOr<TypeRange> {
std::optional<OperandRange> regionReturnOperands;
for (RegionBranchTerminatorOpInterface regionReturnOp : regionReturnOps) {
auto terminatorOperands =
@@ -206,7 +211,7 @@ LogicalResult detail::verifyTypesAlongControlFlowEdges(Operation *op) {
if (!areTypesCompatible(regionReturnOperands->getTypes(),
terminatorOperands.getTypes())) {
InFlightDiagnostic diag = op->emitOpError(" along control flow edge");
- return printRegionEdgeName(diag, region, succRegionNo)
+ return printRegionEdgeName(diag, regionNo, succRegionNo)
<< " operands mismatch between return-like terminators";
}
}
@@ -215,7 +220,7 @@ LogicalResult detail::verifyTypesAlongControlFlowEdges(Operation *op) {
return TypeRange(regionReturnOperands->getTypes());
};
- if (failed(verifyTypesAlongAllEdges(op, region, inputTypesForRegion)))
+ if (failed(verifyTypesAlongAllEdges(op, regionNo, inputTypesForRegion)))
return failure();
}
@@ -232,24 +237,24 @@ static bool isRegionReachable(Region *begin, Region *r) {
visited[begin->getRegionNumber()] = true;
// Retrieve all successors of the region and enqueue them in the worklist.
- SmallVector<Region *> worklist;
- auto enqueueAllSuccessors = [&](Region *region) {
+ SmallVector<unsigned> worklist;
+ auto enqueueAllSuccessors = [&](unsigned index) {
SmallVector<RegionSuccessor> successors;
- op.getSuccessorRegions(region, successors);
+ op.getSuccessorRegions(index, successors);
for (RegionSuccessor successor : successors)
if (!successor.isParent())
- worklist.push_back(successor.getSuccessor());
+ worklist.push_back(successor.getSuccessor()->getRegionNumber());
};
- enqueueAllSuccessors(begin);
+ enqueueAllSuccessors(begin->getRegionNumber());
// Process all regions in the worklist via DFS.
while (!worklist.empty()) {
- Region *nextRegion = worklist.pop_back_val();
- if (nextRegion == r)
+ unsigned nextRegion = worklist.pop_back_val();
+ if (nextRegion == r->getRegionNumber())
return true;
- if (visited[nextRegion->getRegionNumber()])
+ if (visited[nextRegion])
continue;
- visited[nextRegion->getRegionNumber()] = true;
+ visited[nextRegion] = true;
enqueueAllSuccessors(nextRegion);
}
diff --git a/mlir/lib/Transforms/RemoveDeadValues.cpp b/mlir/lib/Transforms/RemoveDeadValues.cpp
index 19a84db34dce02..ce19dc667f009d 100644
--- a/mlir/lib/Transforms/RemoveDeadValues.cpp
+++ b/mlir/lib/Transforms/RemoveDeadValues.cpp
@@ -316,11 +316,15 @@ static void cleanRegionBranchOp(RegionBranchOpInterface regionBranchOp,
// Return the successors of `region` if the latter is not null. Else return
// the successors of `regionBranchOp`.
auto getSuccessors = [&](Region *region = nullptr) {
- auto point = region ? region : RegionBranchPoint::parent();
+ std::optional<unsigned> index =
+ region ? std::optional(region->getRegionNumber()) : std::nullopt;
SmallVector<Attribute> operandAttributes(regionBranchOp->getNumOperands(),
nullptr);
SmallVector<RegionSuccessor> successors;
- regionBranchOp.getSuccessorRegions(point, successors);
+ if (!index)
+ regionBranchOp.getEntrySuccessorRegions(operandAttributes, successors);
+ else
+ regionBranchOp.getSuccessorRegions(index, successors);
return successors;
};
@@ -329,10 +333,14 @@ static void cleanRegionBranchOp(RegionBranchOpInterface regionBranchOp,
// forwarded to `successor`.
auto getForwardedOpOperands = [&](const RegionSuccessor &successor,
Operation *terminator = nullptr) {
+ Region *successorRegion = successor.getSuccessor();
+ std::optional<unsigned> index =
+ successorRegion ? std::optional(successorRegion->getRegionNumber())
+ : std::nullopt;
OperandRange operands =
terminator ? cast<RegionBranchTerminatorOpInterface>(terminator)
- .getSuccessorOperands(successor)
- : regionBranchOp.getEntrySuccessorOperands(successor);
+ .getSuccessorOperands(index)
+ : regionBranchOp.getEntrySuccessorOperands(index);
SmallVector<OpOperand *> opOperands = operandsToOpOperands(operands);
return opOperands;
};
diff --git a/mlir/test/lib/Analysis/DataFlow/TestDenseBackwardDataFlowAnalysis.cpp b/mlir/test/lib/Analysis/DataFlow/TestDenseBackwardDataFlowAnalysis.cpp
index 8bfd01d828060a..a33b523d5d192f 100644
--- a/mlir/test/lib/Analysis/DataFlow/TestDenseBackwardDataFlowAnalysis.cpp
+++ b/mlir/test/lib/Analysis/DataFlow/TestDenseBackwardDataFlowAnalysis.cpp
@@ -60,8 +60,8 @@ class NextAccessAnalysis : public DenseBackwardDataFlowAnalysis<NextAccess> {
NextAccess *before) override;
void visitRegionBranchControlFlowTransfer(RegionBranchOpInterface branch,
- RegionBranchPoint regionFrom,
- RegionBranchPoint regionTo,
+ std::optional<unsigned> regionFrom,
+ std::optional<unsigned> regionTo,
const NextAccess &after,
NextAccess *before) override;
@@ -124,15 +124,15 @@ void NextAccessAnalysis::visitCallControlFlowTransfer(
}
void NextAccessAnalysis::visitRegionBranchControlFlowTransfer(
- RegionBranchOpInterface branch, RegionBranchPoint regionFrom,
- RegionBranchPoint regionTo, const NextAccess &after, NextAccess *before) {
+ RegionBranchOpInterface branch, std::optional<unsigned> regionFrom,
+ std::optional<unsigned> regionTo, const NextAccess &after,
+ NextAccess *before) {
auto testStoreWithARegion =
dyn_cast<::test::TestStoreWithARegion>(branch.getOperation());
if (testStoreWithARegion &&
- ((regionTo.isParent() && !testStoreWithARegion.getStoreBeforeRegion()) ||
- (regionFrom.isParent() &&
- testStoreWithARegion.getStoreBeforeRegion()))) {
+ ((!regionTo && !testStoreWithARegion.getStoreBeforeRegion()) ||
+ (!regionFrom && testStoreWithARegion.getStoreBeforeRegion()))) {
visitOperation(branch, static_cast<const NextAccess &>(after),
static_cast<NextAccess *>(before));
} else {
@@ -219,7 +219,7 @@ struct TestNextAccessPass
SmallVector<Attribute> entryPointNextAccess;
SmallVector<RegionSuccessor> regionSuccessors;
- iface.getSuccessorRegions(RegionBranchPoint::parent(), regionSuccessors);
+ iface.getSuccessorRegions(std::nullopt, regionSuccessors);
for (const RegionSuccessor &successor : regionSuccessors) {
if (!successor.getSuccessor() || successor.getSuccessor()->empty())
continue;
diff --git a/mlir/test/lib/Dialect/Test/TestDialect.cpp b/mlir/test/lib/Dialect/Test/TestDialect.cpp
index 34ed7a1a66fe33..57a6ab387281dc 100644
--- a/mlir/test/lib/Dialect/Test/TestDialect.cpp
+++ b/mlir/test/lib/Dialect/Test/TestDialect.cpp
@@ -931,17 +931,17 @@ ParseResult RegionIfOp::parse(OpAsmParser &parser, OperationState &result) {
parser.getCurrentLocation(), result.operands);
}
-OperandRange RegionIfOp::getEntrySuccessorOperands(RegionBranchPoint point) {
- assert(llvm::is_contained({&getThenRegion(), &getElseRegion()}, point) &&
- "invalid region index");
+OperandRange
+RegionIfOp::getEntrySuccessorOperands(std::optional<unsigned> index) {
+ assert(index && *index < 2 && "invalid region index");
return getOperands();
}
void RegionIfOp::getSuccessorRegions(
- RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
+ std::optional<unsigned> index, SmallVectorImpl<RegionSuccessor> &regions) {
// We always branch to the join region.
- if (!point.isParent()) {
- if (point != getJoinRegion())
+ if (index.has_value()) {
+ if (index.value() < 2)
regions.push_back(RegionSuccessor(&getJoinRegion(), getJoinArgs()));
else
regions.push_back(RegionSuccessor(getResults()));
@@ -964,11 +964,11 @@ void RegionIfOp::getRegionInvocationBounds(
// AnyCondOp
//===----------------------------------------------------------------------===//
-void AnyCondOp::getSuccessorRegions(RegionBranchPoint point,
+void AnyCondOp::getSuccessorRegions(std::optional<unsigned> index,
SmallVectorImpl<RegionSuccessor> &regions) {
// The parent op branches into the only region, and the region branches back
// to the parent op.
- if (point.isParent())
+ if (!index)
regions.emplace_back(&getRegion());
else
regions.emplace_back(getResults());
@@ -985,16 +985,17 @@ void AnyCondOp::getRegionInvocationBounds(
//===----------------------------------------------------------------------===//
void LoopBlockOp::getSuccessorRegions(
- RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
+ std::optional<unsigned> index, SmallVectorImpl<RegionSuccessor> &regions) {
regions.emplace_back(&getBody(), getBody().getArguments());
- if (point.isParent())
+ if (!index)
return;
regions.emplace_back((*this)->getResults());
}
-OperandRange LoopBlockOp::getEntrySuccessorOperands(RegionBranchPoint point) {
- assert(point == getBody());
+OperandRange
+LoopBlockOp::getEntrySuccessorOperands(std::optional<unsigned> index) {
+ assert(index == 0);
return getInitMutable();
}
@@ -1002,9 +1003,10 @@ OperandRange LoopBlockOp::getEntrySuccessorOperands(RegionBranchPoint point) {
// LoopBlockTerminatorOp
//===----------------------------------------------------------------------===//
-MutableOperandRange
-LoopBlockTerminatorOp::getMutableSuccessorOperands(RegionBranchPoint point) {
- if (point.isParent())
+MutableOperandRange LoopBlockTerminatorOp::getMutableSuccessorOperands(
+ std::optional<unsigned> index) {
+ assert(!index || index == 0);
+ if (!index)
return getExitArgMutable();
return getNextIterArgMutable();
}
@@ -1311,11 +1313,12 @@ MutableOperandRange TestCallOnDeviceOp::getArgOperandsMutable() {
}
void TestStoreWithARegion::getSuccessorRegions(
- RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
- if (point.isParent())
+ std::optional<unsigned> index, SmallVectorImpl<RegionSuccessor> &regions) {
+ if (!index) {
regions.emplace_back(&getBody(), getBody().front().getArguments());
- else
+ } else {
regions.emplace_back();
+ }
}
LogicalResult
diff --git a/mlir/unittests/Interfaces/ControlFlowInterfacesTest.cpp b/mlir/unittests/Interfaces/ControlFlowInterfacesTest.cpp
index f1aae15393fd3f..a507baa6445d97 100644
--- a/mlir/unittests/Interfaces/ControlFlowInterfacesTest.cpp
+++ b/mlir/unittests/Interfaces/ControlFlowInterfacesTest.cpp
@@ -37,7 +37,7 @@ struct MutuallyExclusiveRegionsOp
}
// Regions have no successors.
- void getSuccessorRegions(RegionBranchPoint point,
+ void getSuccessorRegions(std::optional<unsigned> index,
SmallVectorImpl<RegionSuccessor> &regions) {}
};
@@ -51,13 +51,14 @@ struct LoopRegionsOp
static StringRef getOperationName() { return "cftest.loop_regions_op"; }
- void getSuccessorRegions(RegionBranchPoint point,
+ void getSuccessorRegions(std::optional<unsigned> index,
SmallVectorImpl<RegionSuccessor> &regions) {
- if (Region *region = point.getRegionOrNull()) {
- if (point == (*this)->getRegion(1))
+ if (index) {
+ if (*index == 1)
// This region also branches back to the parent.
regions.push_back(RegionSuccessor());
- regions.push_back(RegionSuccessor(region));
+ regions.push_back(
+ RegionSuccessor(&getOperation()->getRegion(*index % kNumRegions)));
}
}
};
@@ -73,11 +74,11 @@ struct DoubleLoopRegionsOp
return "cftest.double_loop_regions_op";
}
- void getSuccessorRegions(RegionBranchPoint point,
+ void getSuccessorRegions(std::optional<unsigned> index,
SmallVectorImpl<RegionSuccessor> &regions) {
- if (Region *region = point.getRegionOrNull()) {
+ if (index.has_value()) {
regions.push_back(RegionSuccessor());
- regions.push_back(RegionSuccessor(region));
+ regions.push_back(RegionSuccessor(&getOperation()->getRegion(*index)));
}
}
};
@@ -91,9 +92,9 @@ struct SequentialRegionsOp
static StringRef getOperationName() { return "cftest.sequential_regions_op"; }
// Region 0 has Region 1 as a successor.
- void getSuccessorRegions(RegionBranchPoint point,
+ void getSuccessorRegions(std::optional<unsigned> index,
SmallVectorImpl<RegionSuccessor> &regions) {
- if (point == (*this)->getRegion(0)) {
+ if (index == 0u) {
Operation *thisOp = this->getOperation();
regions.push_back(RegionSuccessor(&thisOp->getRegion(1)));
}
More information about the Mlir-commits
mailing list