[flang-commits] [flang] 138df29 - [mlir] Revamp `RegionBranchOpInterface` successor mechanism
Markus Böck via flang-commits
flang-commits at lists.llvm.org
Thu Aug 10 01:35:45 PDT 2023
Author: Markus Böck
Date: 2023-08-10T10:27:27+02:00
New Revision: 138df298208a095dc9bb9e5d1e3c67231b0abd77
URL: https://github.com/llvm/llvm-project/commit/138df298208a095dc9bb9e5d1e3c67231b0abd77
DIFF: https://github.com/llvm/llvm-project/commit/138df298208a095dc9bb9e5d1e3c67231b0abd77.diff
LOG: [mlir] Revamp `RegionBranchOpInterface` successor mechanism
The `RegionBranchOpInterface` had a few fundamental issues caused by the API design of `getSuccessorRegions`.
It always required passing values for the `operands` parameter. This is problematic because the meaning of the `operands` parameter changes depending on which predecessor `index` refers to. If coming from a region, you'd have to find a `RegionBranchTerminatorOpInterface` in that region, get its operand count, and then create a `SmallVector` of that size.
This is not only inconvenient but also error-prone, which has led to a bug in the implementation of a previously existing `getSuccessorRegions` overload.
Additionally, this made the method dual-use, serving two different use cases: 1) determining the possible control-flow edges between regions, and 2) determining the region being branched to based on constant operands.
This patch fixes these issues by changing the interface methods and adding new ones:
* The `operands` argument of `getSuccessorRegions` has been removed. The method is now only responsible for returning possible control flow edges between regions.
* An optional `getEntrySuccessorRegions` method has been added. This is used to determine which regions are branched to from the parent op based on constant operands of the parent op. By default, it calls `getSuccessorRegions`. This is analogous to `getSuccessorForOperands` from `BranchOpInterface`.
* A `getSuccessorRegions` method has been added to `RegionBranchTerminatorOpInterface`. It is used to get the possible successors of the terminator based on constant operands. By default, it calls the containing `RegionBranchOpInterface`'s `getSuccessorRegions` method.
* `getSuccessorEntryOperands` was renamed to `getEntrySuccessorOperands` for consistency.
Differential Revision: https://reviews.llvm.org/D157506
Added:
Modified:
flang/include/flang/Optimizer/Dialect/FIROps.td
flang/lib/Optimizer/Dialect/FIROps.cpp
mlir/include/mlir/Analysis/DataFlow/DeadCodeAnalysis.h
mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
mlir/include/mlir/Dialect/Async/IR/AsyncOps.td
mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
mlir/include/mlir/Interfaces/ControlFlowInterfaces.td
mlir/lib/Analysis/AliasAnalysis/LocalAliasAnalysis.cpp
mlir/lib/Analysis/DataFlow/DeadCodeAnalysis.cpp
mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp
mlir/lib/Dialect/Affine/IR/AffineOps.cpp
mlir/lib/Dialect/Async/IR/Async.cpp
mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocation.cpp
mlir/lib/Dialect/Bufferization/Transforms/BufferViewFlowAnalysis.cpp
mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
mlir/lib/Dialect/SCF/IR/SCF.cpp
mlir/lib/Dialect/Shape/IR/Shape.cpp
mlir/lib/Dialect/Transform/IR/TransformOps.cpp
mlir/lib/Dialect/Vector/IR/VectorOps.cpp
mlir/lib/Interfaces/ControlFlowInterfaces.cpp
mlir/test/lib/Dialect/Test/TestDialect.cpp
mlir/test/lib/Dialect/Test/TestOps.td
mlir/unittests/Interfaces/ControlFlowInterfacesTest.cpp
Removed:
################################################################################
diff --git a/flang/include/flang/Optimizer/Dialect/FIROps.td b/flang/include/flang/Optimizer/Dialect/FIROps.td
index f07e8009cf2c24..945af1d61d356e 100644
--- a/flang/include/flang/Optimizer/Dialect/FIROps.td
+++ b/flang/include/flang/Optimizer/Dialect/FIROps.td
@@ -2147,7 +2147,7 @@ def fir_DoLoopOp : region_Op<"do_loop",
}
def fir_IfOp : region_Op<"if", [DeclareOpInterfaceMethods<RegionBranchOpInterface, [
- "getRegionInvocationBounds"]>, RecursiveMemoryEffects,
+ "getRegionInvocationBounds", "getEntrySuccessorRegions"]>, RecursiveMemoryEffects,
NoRegionArguments]> {
let summary = "if-then-else conditional operation";
let description = [{
diff --git a/flang/lib/Optimizer/Dialect/FIROps.cpp b/flang/lib/Optimizer/Dialect/FIROps.cpp
index 11fc8473d1734c..bbe06577c27e7b 100644
--- a/flang/lib/Optimizer/Dialect/FIROps.cpp
+++ b/flang/lib/Optimizer/Dialect/FIROps.cpp
@@ -3461,15 +3461,13 @@ void fir::IfOp::build(mlir::OpBuilder &builder, mlir::OperationState &result,
}
}
-// These 2 functions copied from scf.if implementation.
+// These 3 functions copied from scf.if implementation.
/// Given the region at `index`, or the parent operation if `index` is None,
/// return the successor regions. These are the regions that may be selected
-/// during the flow of control. `operands` is a set of optional attributes that
-/// correspond to a constant value for each operand, or null if that operand is
-/// not a constant.
+/// during the flow of control.
void fir::IfOp::getSuccessorRegions(
- std::optional<unsigned> index, llvm::ArrayRef<mlir::Attribute> operands,
+ std::optional<unsigned> index,
llvm::SmallVectorImpl<mlir::RegionSuccessor> &regions) {
// The `then` and the `else` region branch back to the parent operation.
if (index) {
@@ -3477,27 +3475,33 @@ void fir::IfOp::getSuccessorRegions(
return;
}
+ // The `then` region is always a viable successor.
+ regions.push_back(mlir::RegionSuccessor(&getThenRegion()));
+
// Don't consider the else region if it is empty.
mlir::Region *elseRegion = &this->getElseRegion();
if (elseRegion->empty())
- elseRegion = nullptr;
+ regions.push_back(mlir::RegionSuccessor());
+ else
+ regions.push_back(mlir::RegionSuccessor(elseRegion));
+}
- // Otherwise, the successor is dependent on the condition.
- bool condition;
- if (auto condAttr = operands.front().dyn_cast_or_null<mlir::IntegerAttr>()) {
- condition = condAttr.getValue().isOne();
- } else {
- // If the condition isn't constant, both regions may be executed.
- regions.push_back(mlir::RegionSuccessor(&getThenRegion()));
- // If the else region does not exist, it is not a viable successor.
- if (elseRegion)
- regions.push_back(mlir::RegionSuccessor(elseRegion));
- return;
+void fir::IfOp::getEntrySuccessorRegions(
+ llvm::ArrayRef<mlir::Attribute> operands,
+ llvm::SmallVectorImpl<mlir::RegionSuccessor> &regions) {
+ FoldAdaptor adaptor(operands);
+ auto boolAttr =
+ mlir::dyn_cast_or_null<mlir::BoolAttr>(adaptor.getCondition());
+ if (!boolAttr || boolAttr.getValue())
+ regions.emplace_back(&getThenRegion());
+
+ // If the else region is empty, execution continues after the parent op.
+ if (!boolAttr || !boolAttr.getValue()) {
+ if (!getElseRegion().empty())
+ regions.emplace_back(&getElseRegion());
+ else
+ regions.emplace_back(getResults());
}
-
- // Add the successor regions using the condition.
- regions.push_back(
- mlir::RegionSuccessor(condition ? &getThenRegion() : elseRegion));
}
void fir::IfOp::getRegionInvocationBounds(
diff --git a/mlir/include/mlir/Analysis/DataFlow/DeadCodeAnalysis.h b/mlir/include/mlir/Analysis/DataFlow/DeadCodeAnalysis.h
index 035566fbc15751..7a6fea8326a58b 100644
--- a/mlir/include/mlir/Analysis/DataFlow/DeadCodeAnalysis.h
+++ b/mlir/include/mlir/Analysis/DataFlow/DeadCodeAnalysis.h
@@ -26,6 +26,7 @@ class CallOpInterface;
class CallableOpInterface;
class BranchOpInterface;
class RegionBranchOpInterface;
+class RegionBranchTerminatorOpInterface;
namespace dataflow {
@@ -207,7 +208,8 @@ class DeadCodeAnalysis : public DataFlowAnalysis {
/// Visit the given terminator operation that exits a region under an
/// operation with control-flow semantics. These are terminators with no CFG
/// successors.
- void visitRegionTerminator(Operation *op, RegionBranchOpInterface branch);
+ void visitRegionTerminator(RegionBranchTerminatorOpInterface op,
+ RegionBranchOpInterface branch);
/// Visit the given terminator operation that exits a callable region. These
/// are terminators with no CFG successors.
diff --git a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
index 0ba7782a2645b6..628586d33bf4fd 100644
--- a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
+++ b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
@@ -123,7 +123,7 @@ def AffineForOp : Affine_Op<"for",
["getSingleInductionVar", "getSingleLowerBound", "getSingleStep",
"getSingleUpperBound"]>,
DeclareOpInterfaceMethods<RegionBranchOpInterface,
- ["getSuccessorEntryOperands"]>]> {
+ ["getEntrySuccessorOperands"]>]> {
let summary = "for operation";
let description = [{
Syntax:
diff --git a/mlir/include/mlir/Dialect/Async/IR/AsyncOps.td b/mlir/include/mlir/Dialect/Async/IR/AsyncOps.td
index 54ad3c63189c8d..52117769d18045 100644
--- a/mlir/include/mlir/Dialect/Async/IR/AsyncOps.td
+++ b/mlir/include/mlir/Dialect/Async/IR/AsyncOps.td
@@ -35,7 +35,7 @@ class Async_Op<string mnemonic, list<Trait> traits = []> :
def Async_ExecuteOp :
Async_Op<"execute", [SingleBlockImplicitTerminator<"YieldOp">,
DeclareOpInterfaceMethods<RegionBranchOpInterface,
- ["getSuccessorEntryOperands",
+ ["getEntrySuccessorOperands",
"areTypesCompatible"]>,
AttrSizedOperandSegments,
AutomaticAllocationScope]> {
@@ -312,8 +312,7 @@ def Async_ReturnOp : Async_Op<"return",
def Async_YieldOp :
Async_Op<"yield", [
- HasParent<"ExecuteOp">, Pure, Terminator,
- DeclareOpInterfaceMethods<RegionBranchTerminatorOpInterface>]> {
+ HasParent<"ExecuteOp">, Pure, Terminator, ReturnLike]> {
let summary = "terminator for Async execute operation";
let description = [{
The `async.yield` is a special terminator operation for the block inside
@@ -322,7 +321,6 @@ def Async_YieldOp :
let arguments = (ins Variadic<AnyType>:$operands);
let assemblyFormat = "($operands^ `:` type($operands))? attr-dict";
- let hasVerifier = 1;
}
def Async_AwaitOp : Async_Op<"await"> {
diff --git a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
index db7b41afb7ed37..dd9a350d64561c 100644
--- a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
+++ b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
@@ -39,7 +39,8 @@ class SCF_Op<string mnemonic, list<Trait> traits = []> :
def ConditionOp : SCF_Op<"condition", [
HasParent<"WhileOp">,
- DeclareOpInterfaceMethods<RegionBranchTerminatorOpInterface>,
+ DeclareOpInterfaceMethods<RegionBranchTerminatorOpInterface,
+ ["getSuccessorRegions"]>,
Pure,
Terminator
]> {
@@ -124,7 +125,8 @@ def ForOp : SCF_Op<"for",
"getSingleUpperBound", "promoteIfSingleIteration"]>,
AllTypesMatch<["lowerBound", "upperBound", "step"]>,
ConditionallySpeculatable,
- DeclareOpInterfaceMethods<RegionBranchOpInterface>,
+ DeclareOpInterfaceMethods<RegionBranchOpInterface,
+ ["getEntrySuccessorOperands"]>,
SingleBlockImplicitTerminator<"scf::YieldOp">,
RecursiveMemoryEffects]> {
let summary = "for operation";
@@ -335,12 +337,6 @@ def ForOp : SCF_Op<"for",
getNumControlOperands() + opResult.getResultNumber());
}
- /// Return operands used when entering the region at 'index'. These operands
- /// correspond to the loop iterator operands, i.e., those exclusing the
- /// induction variable. LoopOp only has one region, so 0 is the only valid
- /// value for `index`.
- OperandRange getSuccessorEntryOperands(std::optional<unsigned> index);
-
/// Returns the step as an `APInt` if it is constant.
std::optional<APInt> getConstantStep();
@@ -712,7 +708,8 @@ def InParallelOp : SCF_Op<"forall.in_parallel", [
//===----------------------------------------------------------------------===//
def IfOp : SCF_Op<"if", [DeclareOpInterfaceMethods<RegionBranchOpInterface, [
- "getNumRegionInvocations", "getRegionInvocationBounds"]>,
+ "getNumRegionInvocations", "getRegionInvocationBounds",
+ "getEntrySuccessorRegions"]>,
InferTypeOpAdaptor, SingleBlockImplicitTerminator<"scf::YieldOp">,
RecursiveMemoryEffects, NoRegionArguments]> {
let summary = "if-then-else operation";
@@ -978,7 +975,8 @@ def ReduceReturnOp :
//===----------------------------------------------------------------------===//
def WhileOp : SCF_Op<"while",
- [DeclareOpInterfaceMethods<RegionBranchOpInterface>,
+ [DeclareOpInterfaceMethods<RegionBranchOpInterface,
+ ["getEntrySuccessorOperands"]>,
RecursiveMemoryEffects]> {
let summary = "a generic 'while' loop";
let description = [{
@@ -1108,7 +1106,6 @@ def WhileOp : SCF_Op<"while",
using BodyBuilderFn =
function_ref<void(OpBuilder &, Location, ValueRange)>;
- OperandRange getSuccessorEntryOperands(std::optional<unsigned> index);
ConditionOp getConditionOp();
YieldOp getYieldOp();
Block::BlockArgListType getBeforeArguments();
@@ -1127,7 +1124,8 @@ def WhileOp : SCF_Op<"while",
def IndexSwitchOp : SCF_Op<"index_switch", [RecursiveMemoryEffects,
SingleBlockImplicitTerminator<"scf::YieldOp">,
DeclareOpInterfaceMethods<RegionBranchOpInterface,
- ["getRegionInvocationBounds"]>]> {
+ ["getRegionInvocationBounds",
+ "getEntrySuccessorRegions"]>]> {
let summary = "switch-case operation on an index argument";
let description = [{
The `scf.index_switch` is a control-flow operation that branches to one of
diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
index ca22121e62766e..f21e4bf6127e11 100644
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
@@ -25,7 +25,7 @@ include "mlir/Dialect/Transform/IR/TransformInterfaces.td"
def AlternativesOp : TransformDialectOp<"alternatives",
[DeclareOpInterfaceMethods<RegionBranchOpInterface,
- ["getSuccessorEntryOperands", "getSuccessorRegions",
+ ["getEntrySuccessorOperands", "getSuccessorRegions",
"getRegionInvocationBounds"]>,
DeclareOpInterfaceMethods<TransformOpInterface>,
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
@@ -507,7 +507,7 @@ def ForeachOp : TransformDialectOp<"foreach",
[DeclareOpInterfaceMethods<TransformOpInterface>,
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
DeclareOpInterfaceMethods<RegionBranchOpInterface, [
- "getSuccessorRegions", "getSuccessorEntryOperands"]>,
+ "getSuccessorRegions", "getEntrySuccessorOperands"]>,
SingleBlockImplicitTerminator<"::mlir::transform::YieldOp">
]> {
let summary = "Executes the body for each payload op";
@@ -1016,7 +1016,7 @@ def SelectOp : TransformDialectOp<"select",
def SequenceOp : TransformDialectOp<"sequence",
[DeclareOpInterfaceMethods<RegionBranchOpInterface,
- ["getSuccessorEntryOperands", "getSuccessorRegions",
+ ["getEntrySuccessorOperands", "getSuccessorRegions",
"getRegionInvocationBounds"]>,
MatchOpInterface,
DeclareOpInterfaceMethods<TransformOpInterface>,
diff --git a/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td b/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td
index b0cea5c5565c4e..132bd6d53d923a 100644
--- a/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td
+++ b/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td
@@ -134,36 +134,50 @@ def RegionBranchOpInterface : OpInterface<"RegionBranchOpInterface"> {
InterfaceMethod<[{
Returns the operands of this operation used as the entry arguments when
entering the region at `index`, which was specified as a successor of
- this operation by `getSuccessorRegions`, or the operands forwarded to
- the operation's results when it branches back to itself. These operands
+ this operation by `getEntrySuccessorRegions`, or the operands forwarded
+ to the operation's results when it branches back to itself. These operands
should correspond 1-1 with the successor inputs specified in
- `getSuccessorRegions`.
+ `getEntrySuccessorRegions`.
}],
- "::mlir::OperandRange", "getSuccessorEntryOperands",
+ "::mlir::OperandRange", "getEntrySuccessorOperands",
(ins "::std::optional<unsigned>":$index), [{}],
/*defaultImplementation=*/[{
auto operandEnd = this->getOperation()->operand_end();
return ::mlir::OperandRange(operandEnd, operandEnd);
}]
>,
+ InterfaceMethod<[{
+ Returns the viable region successors that are branched to when first
+ executing the op.
+ Unlike `getSuccessorRegions`, this method also passes along the
+ constant operands of this op. Based on these, different region
+ successors can be determined.
+ `operands` contains an entry for every operand of the implementing
+ op with a null attribute if the operand has no constant value or
+ the corresponding attribute if it is a constant.
+
+ By default, simply dispatches to `getSuccessorRegions`.
+ }],
+ "void", "getEntrySuccessorRegions",
+ (ins "::llvm::ArrayRef<::mlir::Attribute>":$operands,
+ "::llvm::SmallVectorImpl<::mlir::RegionSuccessor> &":$regions),
+ [{}], [{
+ $_op.getSuccessorRegions(std::nullopt, regions);
+ }]
+ >,
InterfaceMethod<[{
Returns the viable successors of a region at `index`, or the possible
successors when branching from the parent op if `index` is None. These
- are the regions that may be selected during the flow of control. If
- `index` is None, `operands` is a set of optional attributes that
- either correspond to a constant value for each operand of this
- operation, or null if that operand is not a constant. If `index` is
- valid, `operands` corresponds to the entry values of the region at
- `index`. The parent operation, i.e. a null `index`, may specify itself
- as successor, which indicates that the control flow may not enter any
- region at all. This method allows for describing which
- regions may be executed when entering an operation, and which regions
- are executed after having executed another region of the parent op. The
- successor region must be non-empty.
+ are the regions that may be selected during the flow of control. The
+ parent operation, i.e. a null `index`, may specify itself as successor,
+ which indicates that the control flow may not enter any region at all.
+ This method allows for describing which regions may be executed when
+ entering an operation, and which regions are executed after having
+ executed another region of the parent op. The successor region must be
+ non-empty.
}],
"void", "getSuccessorRegions",
(ins "::std::optional<unsigned>":$index,
- "::llvm::ArrayRef<::mlir::Attribute>":$operands,
"::llvm::SmallVectorImpl<::mlir::RegionSuccessor> &":$regions)
>,
InterfaceMethod<[{
@@ -208,10 +222,6 @@ def RegionBranchOpInterface : OpInterface<"RegionBranchOpInterface"> {
let verifyWithRegions = 1;
let extraClassDeclaration = [{
- /// Convenience helper in case none of the operands is known.
- void getSuccessorRegions(std::optional<unsigned> index,
- SmallVectorImpl<RegionSuccessor> &regions);
-
/// Return `true` if control flow originating from the given region may
/// eventually branch back to the same region. (Maybe after passing through
/// other regions.)
@@ -243,17 +253,26 @@ def RegionBranchTerminatorOpInterface :
(ins "::std::optional<unsigned>":$index)
>,
InterfaceMethod<[{
- Returns a range of operands that are semantically "returned" by passing
- them to the region successor given by `index`. If `index` is None, this
- function returns the operands that are passed as a result to the parent
- operation.
+ Returns the viable region successors that are branched to after this
+ terminator based on the given constant operands.
+
+ `operands` contains an entry for every operand of the
+ implementing op with a null attribute if the operand has no constant
+ value or the corresponding attribute if it is a constant.
+
+ Default implementation simply dispatches to the parent
+ `RegionBranchOpInterface`'s `getSuccessorRegions` implementation.
}],
- "::mlir::OperandRange", "getSuccessorOperands",
- (ins "::std::optional<unsigned>":$index), [{}],
- /*defaultImplementation=*/[{
- return $_op.getMutableSuccessorOperands(index);
+ "void", "getSuccessorRegions",
+ (ins "::llvm::ArrayRef<::mlir::Attribute>":$operands,
+ "::llvm::SmallVectorImpl<::mlir::RegionSuccessor> &":$regions), [{}],
+ [{
+ ::mlir::Operation *op = $_op;
+ ::llvm::cast<::mlir::RegionBranchOpInterface>(op->getParentOp())
+ .getSuccessorRegions(op->getParentRegion()->getRegionNumber(),
+ regions);
}]
- >
+ >,
];
let verify = [{
@@ -265,6 +284,16 @@ def RegionBranchTerminatorOpInterface :
"expected operation to have zero successors");
return success();
}];
+
+ let extraClassDeclaration = [{
+ // Returns a range of operands that are semantically "returned" by passing
+ // them to the region successor given by `index`. If `index` is None, this
+ // function returns the operands that are passed as a result to the parent
+ // operation.
+ ::mlir::OperandRange getSuccessorOperands(std::optional<unsigned> index) {
+ return getMutableSuccessorOperands(index);
+ }
+ }];
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Analysis/AliasAnalysis/LocalAliasAnalysis.cpp b/mlir/lib/Analysis/AliasAnalysis/LocalAliasAnalysis.cpp
index 7d893f7b918ab4..970e68bc258649 100644
--- a/mlir/lib/Analysis/AliasAnalysis/LocalAliasAnalysis.cpp
+++ b/mlir/lib/Analysis/AliasAnalysis/LocalAliasAnalysis.cpp
@@ -83,7 +83,7 @@ static void collectUnderlyingAddressValues(RegionBranchOpInterface branch,
if (std::optional<unsigned> operandIndex =
getOperandIndexIfPred(/*predIndex=*/std::nullopt)) {
collectUnderlyingAddressValues(
- branch.getSuccessorEntryOperands(regionIndex)[*operandIndex], maxDepth,
+ branch.getEntrySuccessorOperands(regionIndex)[*operandIndex], maxDepth,
visited, output);
}
// Check branches from each child region.
diff --git a/mlir/lib/Analysis/DataFlow/DeadCodeAnalysis.cpp b/mlir/lib/Analysis/DataFlow/DeadCodeAnalysis.cpp
index 30a285068a0748..d423d37b9770c6 100644
--- a/mlir/lib/Analysis/DataFlow/DeadCodeAnalysis.cpp
+++ b/mlir/lib/Analysis/DataFlow/DeadCodeAnalysis.cpp
@@ -259,7 +259,8 @@ LogicalResult DeadCodeAnalysis::visit(ProgramPoint point) {
if (isRegionOrCallableReturn(op)) {
if (auto branch = dyn_cast<RegionBranchOpInterface>(op->getParentOp())) {
// Visit the exiting terminator of a region.
- visitRegionTerminator(op, branch);
+ visitRegionTerminator(cast<RegionBranchTerminatorOpInterface>(op),
+ branch);
} else if (auto callable =
dyn_cast<CallableOpInterface>(op->getParentOp())) {
// Visit the exiting terminator of a callable.
@@ -361,7 +362,7 @@ void DeadCodeAnalysis::visitRegionBranchOperation(
return;
SmallVector<RegionSuccessor> successors;
- branch.getSuccessorRegions(/*index=*/{}, *operands, successors);
+ branch.getEntrySuccessorRegions(*operands, successors);
for (const RegionSuccessor &successor : successors) {
// The successor can be either an entry block or the parent operation.
ProgramPoint point = successor.getSuccessor()
@@ -378,15 +379,14 @@ void DeadCodeAnalysis::visitRegionBranchOperation(
}
}
-void DeadCodeAnalysis::visitRegionTerminator(Operation *op,
- RegionBranchOpInterface branch) {
+void DeadCodeAnalysis::visitRegionTerminator(
+ RegionBranchTerminatorOpInterface op, RegionBranchOpInterface branch) {
std::optional<SmallVector<Attribute>> operands = getOperandValues(op);
if (!operands)
return;
SmallVector<RegionSuccessor> successors;
- branch.getSuccessorRegions(op->getParentRegion()->getRegionNumber(),
- *operands, successors);
+ op.getSuccessorRegions(*operands, successors);
// Mark successor region entry blocks as executable and add this op to the
// list of predecessors.
diff --git a/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp b/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp
index abe754a60cfbda..f8bd754092023d 100644
--- a/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp
+++ b/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp
@@ -224,7 +224,7 @@ void AbstractSparseForwardDataFlowAnalysis::visitRegionSuccessors(
// Check if the predecessor is the parent op.
if (op == branch) {
- operands = branch.getSuccessorEntryOperands(successorIndex);
+ operands = branch.getEntrySuccessorOperands(successorIndex);
// Otherwise, try to deduce the operands from a region return-like op.
} else if (auto regionTerminator =
dyn_cast<RegionBranchTerminatorOpInterface>(op)) {
@@ -479,7 +479,7 @@ void AbstractSparseBackwardDataFlowAnalysis::visitRegionSuccessors(
Operation *op = branch.getOperation();
SmallVector<RegionSuccessor> successors;
SmallVector<Attribute> operands(op->getNumOperands(), nullptr);
- branch.getSuccessorRegions(/*index=*/{}, operands, successors);
+ branch.getEntrySuccessorRegions(operands, successors);
// All operands not forwarded to any successor. This set can be non-contiguous
// in the presence of multiple successors.
@@ -488,8 +488,8 @@ void AbstractSparseBackwardDataFlowAnalysis::visitRegionSuccessors(
for (RegionSuccessor &successor : successors) {
Region *region = successor.getSuccessor();
OperandRange operands =
- region ? branch.getSuccessorEntryOperands(region->getRegionNumber())
- : branch.getSuccessorEntryOperands({});
+ region ? branch.getEntrySuccessorOperands(region->getRegionNumber())
+ : branch.getEntrySuccessorOperands({});
MutableArrayRef<OpOperand> opoperands = operandsToOpOperands(operands);
ValueRange inputs = successor.getSuccessorInputs();
for (auto [operand, input] : llvm::zip(opoperands, inputs)) {
@@ -516,8 +516,7 @@ void AbstractSparseBackwardDataFlowAnalysis::
SmallVector<Attribute> operandAttributes(terminator->getNumOperands(),
nullptr);
SmallVector<RegionSuccessor> successors;
- branch.getSuccessorRegions(terminator->getParentRegion()->getRegionNumber(),
- operandAttributes, successors);
+ terminator.getSuccessorRegions(operandAttributes, successors);
// All operands not forwarded to any successor. This set can be
// non-contiguous in the presence of multiple successors.
BitVector unaccounted(terminator->getNumOperands(), true);
diff --git a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
index 7c14ea75d22539..bb4aaee21d019e 100644
--- a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
+++ b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
@@ -2380,7 +2380,7 @@ void AffineForOp::getCanonicalizationPatterns(RewritePatternSet &results,
/// induction variable. AffineForOp only has one region, so zero is the only
/// valid value for `index`.
OperandRange
-AffineForOp::getSuccessorEntryOperands(std::optional<unsigned> index) {
+AffineForOp::getEntrySuccessorOperands(std::optional<unsigned> index) {
assert((!index || *index == 0) && "invalid region index");
// The initial operands map to the loop arguments after the induction
@@ -2394,8 +2394,7 @@ AffineForOp::getSuccessorEntryOperands(std::optional<unsigned> index) {
/// correspond to a constant value for each operand, or null if that operand is
/// not a constant.
void AffineForOp::getSuccessorRegions(
- std::optional<unsigned> index, ArrayRef<Attribute> operands,
- SmallVectorImpl<RegionSuccessor> &regions) {
+ std::optional<unsigned> index, SmallVectorImpl<RegionSuccessor> &regions) {
assert((!index.has_value() || index.value() == 0) && "expected loop region");
// The loop may typically branch back to its body or to the parent operation.
// If the predecessor is the parent op and the trip count is known to be at
@@ -2860,8 +2859,7 @@ struct AlwaysTrueOrFalseIf : public OpRewritePattern<AffineIfOp> {
/// AffineIfOp has two regions -- `then` and `else`. The flow of data should be
/// as follows: AffineIfOp -> `then`/`else` -> AffineIfOp
void AffineIfOp::getSuccessorRegions(
- std::optional<unsigned> index, ArrayRef<Attribute> operands,
- SmallVectorImpl<RegionSuccessor> &regions) {
+ std::optional<unsigned> index, SmallVectorImpl<RegionSuccessor> &regions) {
// If the predecessor is an AffineIfOp, then branching into both `then` and
// `else` region is valid.
if (!index.has_value()) {
diff --git a/mlir/lib/Dialect/Async/IR/Async.cpp b/mlir/lib/Dialect/Async/IR/Async.cpp
index abe6670c7f855f..9b4fb81990c169 100644
--- a/mlir/lib/Dialect/Async/IR/Async.cpp
+++ b/mlir/lib/Dialect/Async/IR/Async.cpp
@@ -32,31 +32,6 @@ void AsyncDialect::initialize() {
>();
}
-//===----------------------------------------------------------------------===//
-// YieldOp
-//===----------------------------------------------------------------------===//
-
-LogicalResult YieldOp::verify() {
- // Get the underlying value types from async values returned from the
- // parent `async.execute` operation.
- auto executeOp = (*this)->getParentOfType<ExecuteOp>();
- auto types =
- llvm::map_range(executeOp.getBodyResults(), [](const OpResult &result) {
- return llvm::cast<ValueType>(result.getType()).getValueType();
- });
-
- if (getOperandTypes() != types)
- return emitOpError("operand types do not match the types returned from "
- "the parent ExecuteOp");
-
- return success();
-}
-
-MutableOperandRange
-YieldOp::getMutableSuccessorOperands(std::optional<unsigned> index) {
- return getOperandsMutable();
-}
-
//===----------------------------------------------------------------------===//
/// ExecuteOp
//===----------------------------------------------------------------------===//
@@ -64,7 +39,7 @@ YieldOp::getMutableSuccessorOperands(std::optional<unsigned> index) {
constexpr char kOperandSegmentSizesAttr[] = "operandSegmentSizes";
OperandRange
-ExecuteOp::getSuccessorEntryOperands(std::optional<unsigned> index) {
+ExecuteOp::getEntrySuccessorOperands(std::optional<unsigned> index) {
assert(index && *index == 0 && "invalid region index");
return getBodyOperands();
}
@@ -79,7 +54,6 @@ bool ExecuteOp::areTypesCompatible(Type lhs, Type rhs) {
}
void ExecuteOp::getSuccessorRegions(std::optional<unsigned> index,
- ArrayRef<Attribute>,
SmallVectorImpl<RegionSuccessor> &regions) {
// The `body` region branch back to the parent operation.
if (index) {
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocation.cpp b/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocation.cpp
index 40d40cf46f0b6f..b468a6bb3f9097 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocation.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocation.cpp
@@ -384,7 +384,7 @@ class BufferDeallocation : public BufferPlacementTransformationBase {
// Determine the actual operand to introduce a clone for and rewire the
// operand to point to the clone instead.
auto operands =
- regionInterface.getSuccessorEntryOperands(argRegion->getRegionNumber());
+ regionInterface.getEntrySuccessorOperands(argRegion->getRegionNumber());
size_t operandIndex =
llvm::find(it->getSuccessorInputs(), blockArg).getIndex() +
operands.getBeginOperandIndex();
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/BufferViewFlowAnalysis.cpp b/mlir/lib/Dialect/Bufferization/Transforms/BufferViewFlowAnalysis.cpp
index 39b00dd2956bc0..d201e024380661 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/BufferViewFlowAnalysis.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/BufferViewFlowAnalysis.cpp
@@ -106,7 +106,7 @@ void BufferViewFlowAnalysis::build(Operation *op) {
// Wire the entry region's successor arguments with the initial
// successor inputs.
registerDependencies(
- regionInterface.getSuccessorEntryOperands(
+ regionInterface.getEntrySuccessorOperands(
entrySuccessor.isParent()
? std::optional<unsigned>()
: entrySuccessor.getSuccessor()->getRegionNumber()),
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index 5f35adf0ddaab1..b2909962027de2 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -461,8 +461,7 @@ ParseResult AllocaScopeOp::parse(OpAsmParser &parser, OperationState &result) {
}
void AllocaScopeOp::getSuccessorRegions(
- std::optional<unsigned> index, ArrayRef<Attribute> operands,
- SmallVectorImpl<RegionSuccessor> &regions) {
+ std::optional<unsigned> index, SmallVectorImpl<RegionSuccessor> &regions) {
if (index) {
regions.push_back(RegionSuccessor(getResults()));
return;
diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
index 37e22bbc089f01..2cc9e2c895666e 100644
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -266,8 +266,7 @@ void ExecuteRegionOp::getCanonicalizationPatterns(RewritePatternSet &results,
/// correspond to a constant value for each operand, or null if that operand is
/// not a constant.
void ExecuteRegionOp::getSuccessorRegions(
- std::optional<unsigned> index, ArrayRef<Attribute> operands,
- SmallVectorImpl<RegionSuccessor> ®ions) {
+ std::optional<unsigned> index, SmallVectorImpl<RegionSuccessor> ®ions) {
// If the predecessor is the ExecuteRegionOp, branch into the body.
if (!index) {
regions.push_back(RegionSuccessor(&getRegion()));
@@ -288,6 +287,22 @@ ConditionOp::getMutableSuccessorOperands(std::optional<unsigned> index) {
return getArgsMutable();
}
+void ConditionOp::getSuccessorRegions(
+ ArrayRef<Attribute> operands, SmallVectorImpl<RegionSuccessor> ®ions) {
+ FoldAdaptor adaptor(operands);
+
+ WhileOp whileOp = getParentOp();
+
+ // Condition can either lead to the after region or back to the parent op
+ // depending on whether the condition is true or not.
+ auto boolAttr = dyn_cast_or_null<BoolAttr>(adaptor.getCondition());
+ if (!boolAttr || boolAttr.getValue())
+ regions.emplace_back(&whileOp.getAfter(),
+ whileOp.getAfter().getArguments());
+ if (!boolAttr || !boolAttr.getValue())
+ regions.emplace_back(whileOp.getResults());
+}
+
//===----------------------------------------------------------------------===//
// ForOp
//===----------------------------------------------------------------------===//
@@ -535,7 +550,7 @@ ForOp mlir::scf::getForInductionVarOwner(Value val) {
/// Return operands used when entering the region at 'index'. These operands
/// correspond to the loop iterator operands, i.e., those excluding the
/// induction variable.
-OperandRange ForOp::getSuccessorEntryOperands(std::optional<unsigned> index) {
+OperandRange ForOp::getEntrySuccessorOperands(std::optional<unsigned> index) {
return getInitArgs();
}
@@ -545,7 +560,6 @@ OperandRange ForOp::getSuccessorEntryOperands(std::optional<unsigned> index) {
/// correspond to a constant value for each operand, or null if that operand is
/// not a constant.
void ForOp::getSuccessorRegions(std::optional<unsigned> index,
- ArrayRef<Attribute> operands,
SmallVectorImpl<RegionSuccessor> ®ions) {
// Both the operation itself and the region may be branching into the body or
// back into the operation itself. It is possible for loop not to enter the
@@ -1715,7 +1729,6 @@ void ForallOp::getCanonicalizationPatterns(RewritePatternSet &results,
/// correspond to a constant value for each operand, or null if that operand is
/// not a constant.
void ForallOp::getSuccessorRegions(std::optional<unsigned> index,
- ArrayRef<Attribute> operands,
SmallVectorImpl<RegionSuccessor> ®ions) {
// Both the operation itself and the region may be branching into the body or
// back into the operation itself. It is possible for loop not to enter the
@@ -1996,7 +2009,6 @@ void IfOp::print(OpAsmPrinter &p) {
/// correspond to a constant value for each operand, or null if that operand is
/// not a constant.
void IfOp::getSuccessorRegions(std::optional<unsigned> index,
- ArrayRef<Attribute> operands,
SmallVectorImpl<RegionSuccessor> ®ions) {
// The `then` and the `else` region branch back to the parent operation.
if (index) {
@@ -2004,29 +2016,30 @@ void IfOp::getSuccessorRegions(std::optional<unsigned> index,
return;
}
+ regions.push_back(RegionSuccessor(&getThenRegion()));
+
// Don't consider the else region if it is empty.
Region *elseRegion = &this->getElseRegion();
if (elseRegion->empty())
- elseRegion = nullptr;
-
- // Otherwise, the successor is dependent on the condition.
- bool condition;
- if (auto condAttr = llvm::dyn_cast_or_null<IntegerAttr>(operands.front())) {
- condition = condAttr.getValue().isOne();
- } else {
- // If the condition isn't constant, both regions may be executed.
- regions.push_back(RegionSuccessor(&getThenRegion()));
- // If the else region does not exist, it is not a viable successor, so the
- // control will go back to this operation instead.
- if (elseRegion)
- regions.push_back(RegionSuccessor(elseRegion));
+ regions.push_back(RegionSuccessor());
+ else
+ regions.push_back(RegionSuccessor(elseRegion));
+}
+
+void IfOp::getEntrySuccessorRegions(ArrayRef<Attribute> operands,
+ SmallVectorImpl<RegionSuccessor> ®ions) {
+ FoldAdaptor adaptor(operands);
+ auto boolAttr = dyn_cast_or_null<BoolAttr>(adaptor.getCondition());
+ if (!boolAttr || boolAttr.getValue())
+ regions.emplace_back(&getThenRegion());
+
+ // If the else region is empty, execution continues after the parent op.
+ if (!boolAttr || !boolAttr.getValue()) {
+ if (!getElseRegion().empty())
+ regions.emplace_back(&getElseRegion());
else
- regions.push_back(RegionSuccessor());
- return;
+ regions.emplace_back(getResults());
}
-
- // Add the successor regions using the condition.
- regions.push_back(RegionSuccessor(condition ? &getThenRegion() : elseRegion));
}
LogicalResult IfOp::fold(FoldAdaptor adaptor,
@@ -3026,8 +3039,7 @@ void ParallelOp::getCanonicalizationPatterns(RewritePatternSet &results,
/// correspond to a constant value for each operand, or null if that operand is
/// not a constant.
void ParallelOp::getSuccessorRegions(
- std::optional<unsigned> index, ArrayRef<Attribute> operands,
- SmallVectorImpl<RegionSuccessor> ®ions) {
+ std::optional<unsigned> index, SmallVectorImpl<RegionSuccessor> ®ions) {
// Both the operation itself and the region may be branching into the body or
// back into the operation itself. It is possible for loop not to enter the
// body.
@@ -3154,7 +3166,7 @@ void WhileOp::build(::mlir::OpBuilder &odsBuilder,
afterBuilder(odsBuilder, odsState.location, afterBlock->getArguments());
}
-OperandRange WhileOp::getSuccessorEntryOperands(std::optional<unsigned> index) {
+OperandRange WhileOp::getEntrySuccessorOperands(std::optional<unsigned> index) {
assert(index && *index == 0 &&
"WhileOp is expected to branch only to the first region");
@@ -3178,7 +3190,6 @@ Block::BlockArgListType WhileOp::getAfterArguments() {
}
void WhileOp::getSuccessorRegions(std::optional<unsigned> index,
- ArrayRef<Attribute> operands,
SmallVectorImpl<RegionSuccessor> ®ions) {
// The parent op always branches to the condition region.
if (!index) {
@@ -3193,13 +3204,8 @@ void WhileOp::getSuccessorRegions(std::optional<unsigned> index,
return;
}
- // Try to narrow the successor to the condition region.
- assert(!operands.empty() && "expected at least one operand");
- auto cond = llvm::dyn_cast_or_null<BoolAttr>(operands[0]);
- if (!cond || !cond.getValue())
- regions.emplace_back(getResults());
- if (!cond || cond.getValue())
- regions.emplace_back(&getAfter(), getAfter().getArguments());
+ regions.emplace_back(getResults());
+ regions.emplace_back(&getAfter(), getAfter().getArguments());
}
/// Parses a `while` op.
@@ -4016,7 +4022,7 @@ Block &scf::IndexSwitchOp::getCaseBlock(unsigned idx) {
}
void IndexSwitchOp::getSuccessorRegions(
- std::optional<unsigned> index, ArrayRef<Attribute> operands,
+ std::optional<unsigned> index,
SmallVectorImpl<RegionSuccessor> &successors) {
// All regions branch back to the parent op.
if (index) {
@@ -4024,19 +4030,25 @@ void IndexSwitchOp::getSuccessorRegions(
return;
}
+ llvm::copy(getRegions(), std::back_inserter(successors));
+}
+
+void IndexSwitchOp::getEntrySuccessorRegions(
+ ArrayRef<Attribute> operands,
+ SmallVectorImpl<RegionSuccessor> &successors) {
+ FoldAdaptor adaptor(operands);
+
// If a constant was not provided, all regions are possible successors.
- auto operandValue = llvm::dyn_cast_or_null<IntegerAttr>(operands.front());
- if (!operandValue) {
- for (Region &caseRegion : getCaseRegions())
- successors.emplace_back(&caseRegion);
- successors.emplace_back(&getDefaultRegion());
+ auto arg = dyn_cast_or_null<IntegerAttr>(adaptor.getArg());
+ if (!arg) {
+ llvm::copy(getRegions(), std::back_inserter(successors));
return;
}
- // Otherwise, try to find a case with a matching value. If not, the default
- // region is the only successor.
+ // Otherwise, try to find a case with a matching value. If not, the
+ // default region is the only successor.
for (auto [caseValue, caseRegion] : llvm::zip(getCases(), getCaseRegions())) {
- if (caseValue == operandValue.getInt()) {
+ if (caseValue == arg.getInt()) {
successors.emplace_back(&caseRegion);
return;
}
diff --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp
index 3862b889723e5e..c52a9c5004aaaf 100644
--- a/mlir/lib/Dialect/Shape/IR/Shape.cpp
+++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp
@@ -335,8 +335,7 @@ void AssumingOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
// See RegionBranchOpInterface in Interfaces/ControlFlowInterfaces.td
void AssumingOp::getSuccessorRegions(
- std::optional<unsigned> index, ArrayRef<Attribute> operands,
- SmallVectorImpl<RegionSuccessor> ®ions) {
+ std::optional<unsigned> index, SmallVectorImpl<RegionSuccessor> ®ions) {
// AssumingOp has unconditional control flow into the region and back to the
// parent, so return the correct RegionSuccessor purely based on the index
// being None or 0.
diff --git a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
index 21284075f06a9a..fc64aef9681892 100644
--- a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
@@ -85,7 +85,7 @@ ensurePayloadIsSeparateFromTransform(transform::TransformOpInterface transform,
// AlternativesOp
//===----------------------------------------------------------------------===//
-OperandRange transform::AlternativesOp::getSuccessorEntryOperands(
+OperandRange transform::AlternativesOp::getEntrySuccessorOperands(
std::optional<unsigned> index) {
if (index && getOperation()->getNumOperands() == 1)
return getOperation()->getOperands();
@@ -94,8 +94,7 @@ OperandRange transform::AlternativesOp::getSuccessorEntryOperands(
}
void transform::AlternativesOp::getSuccessorRegions(
- std::optional<unsigned> index, ArrayRef<Attribute> operands,
- SmallVectorImpl<RegionSuccessor> ®ions) {
+ std::optional<unsigned> index, SmallVectorImpl<RegionSuccessor> ®ions) {
for (Region &alternative : llvm::drop_begin(
getAlternatives(), index.has_value() ? *index + 1 : 0)) {
regions.emplace_back(&alternative, !getOperands().empty()
@@ -1162,8 +1161,7 @@ void transform::ForeachOp::getEffects(
}
void transform::ForeachOp::getSuccessorRegions(
- std::optional<unsigned> index, ArrayRef<Attribute> operands,
- SmallVectorImpl<RegionSuccessor> ®ions) {
+ std::optional<unsigned> index, SmallVectorImpl<RegionSuccessor> ®ions) {
Region *bodyRegion = &getBody();
if (!index) {
regions.emplace_back(bodyRegion, bodyRegion->getArguments());
@@ -1177,7 +1175,7 @@ void transform::ForeachOp::getSuccessorRegions(
}
OperandRange
-transform::ForeachOp::getSuccessorEntryOperands(std::optional<unsigned> index) {
+transform::ForeachOp::getEntrySuccessorOperands(std::optional<unsigned> index) {
// The iteration variable op handle is mapped to a subset (one op to be
// precise) of the payload ops of the ForeachOp operand.
assert(index && *index == 0 && "unexpected region index");
@@ -2182,7 +2180,7 @@ void transform::SequenceOp::getEffects(
getPotentialTopLevelEffects(effects);
}
-OperandRange transform::SequenceOp::getSuccessorEntryOperands(
+OperandRange transform::SequenceOp::getEntrySuccessorOperands(
std::optional<unsigned> index) {
assert(index && *index == 0 && "unexpected region index");
if (getOperation()->getNumOperands() > 0)
@@ -2192,11 +2190,10 @@ OperandRange transform::SequenceOp::getSuccessorEntryOperands(
}
void transform::SequenceOp::getSuccessorRegions(
- std::optional<unsigned> index, ArrayRef<Attribute> operands,
- SmallVectorImpl<RegionSuccessor> ®ions) {
+ std::optional<unsigned> index, SmallVectorImpl<RegionSuccessor> ®ions) {
if (!index) {
Region *bodyRegion = &getBody();
- regions.emplace_back(bodyRegion, !operands.empty()
+ regions.emplace_back(bodyRegion, getNumOperands() != 0
? bodyRegion->getArguments()
: Block::BlockArgListType());
return;
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 5b416e4a69996f..546f1b9c872ca6 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -5671,8 +5671,7 @@ ParseResult WarpExecuteOnLane0Op::parse(OpAsmParser &parser,
}
void WarpExecuteOnLane0Op::getSuccessorRegions(
- std::optional<unsigned> index, ArrayRef<Attribute> operands,
- SmallVectorImpl<RegionSuccessor> ®ions) {
+ std::optional<unsigned> index, SmallVectorImpl<RegionSuccessor> ®ions) {
if (index) {
regions.push_back(RegionSuccessor(getResults()));
return;
diff --git a/mlir/lib/Interfaces/ControlFlowInterfaces.cpp b/mlir/lib/Interfaces/ControlFlowInterfaces.cpp
index e4eefaa450b89a..b3690ab8961555 100644
--- a/mlir/lib/Interfaces/ControlFlowInterfaces.cpp
+++ b/mlir/lib/Interfaces/ControlFlowInterfaces.cpp
@@ -154,7 +154,7 @@ LogicalResult detail::verifyTypesAlongControlFlowEdges(Operation *op) {
auto inputTypesFromParent =
[&](std::optional<unsigned> regionNo) -> TypeRange {
- return regionInterface.getSuccessorEntryOperands(regionNo).getTypes();
+ return regionInterface.getEntrySuccessorOperands(regionNo).getTypes();
};
// Verify types along control flow edges originating from the parent.
@@ -309,27 +309,6 @@ bool RegionBranchOpInterface::isRepetitiveRegion(unsigned index) {
return isRegionReachable(region, region);
}
-void RegionBranchOpInterface::getSuccessorRegions(
- std::optional<unsigned> index, SmallVectorImpl<RegionSuccessor> ®ions) {
- unsigned numInputs = 0;
- if (index) {
- // If the predecessor is a region, get the number of operands from an
- // exiting terminator in the region.
- for (Block &block : getOperation()->getRegion(*index)) {
- Operation *terminator = block.getTerminator();
- if (isa<RegionBranchTerminatorOpInterface>(terminator)) {
- numInputs = terminator->getNumOperands();
- break;
- }
- }
- } else {
- // Otherwise, use the number of parent operation operands.
- numInputs = getOperation()->getNumOperands();
- }
- SmallVector<Attribute, 2> operands(numInputs, nullptr);
- getSuccessorRegions(index, operands, regions);
-}
-
Region *mlir::getEnclosingRepetitiveRegion(Operation *op) {
while (Region *region = op->getParentRegion()) {
op = region->getParentOp();
diff --git a/mlir/test/lib/Dialect/Test/TestDialect.cpp b/mlir/test/lib/Dialect/Test/TestDialect.cpp
index 485d21823eb67a..ed97efa462be9f 100644
--- a/mlir/test/lib/Dialect/Test/TestDialect.cpp
+++ b/mlir/test/lib/Dialect/Test/TestDialect.cpp
@@ -932,14 +932,13 @@ ParseResult RegionIfOp::parse(OpAsmParser &parser, OperationState &result) {
}
OperandRange
-RegionIfOp::getSuccessorEntryOperands(std::optional<unsigned> index) {
+RegionIfOp::getEntrySuccessorOperands(std::optional<unsigned> index) {
assert(index && *index < 2 && "invalid region index");
return getOperands();
}
void RegionIfOp::getSuccessorRegions(
- std::optional<unsigned> index, ArrayRef<Attribute> operands,
- SmallVectorImpl<RegionSuccessor> ®ions) {
+ std::optional<unsigned> index, SmallVectorImpl<RegionSuccessor> ®ions) {
// We always branch to the join region.
if (index.has_value()) {
if (index.value() < 2)
@@ -966,7 +965,6 @@ void RegionIfOp::getRegionInvocationBounds(
//===----------------------------------------------------------------------===//
void AnyCondOp::getSuccessorRegions(std::optional<unsigned> index,
- ArrayRef<Attribute> operands,
SmallVectorImpl<RegionSuccessor> ®ions) {
// The parent op branches into the only region, and the region branches back
// to the parent op.
@@ -1268,8 +1266,7 @@ MutableOperandRange TestCallAndStoreOp::getArgOperandsMutable() {
}
void TestStoreWithARegion::getSuccessorRegions(
- std::optional<unsigned> index, ArrayRef<Attribute> operands,
- SmallVectorImpl<RegionSuccessor> ®ions) {
+ std::optional<unsigned> index, SmallVectorImpl<RegionSuccessor> ®ions) {
if (!index) {
regions.emplace_back(&getBody(), getBody().front().getArguments());
} else {
@@ -1277,11 +1274,6 @@ void TestStoreWithARegion::getSuccessorRegions(
}
}
-MutableOperandRange TestStoreWithARegionTerminator::getMutableSuccessorOperands(
- std::optional<unsigned> index) {
- return MutableOperandRange(getOperation());
-}
-
LogicalResult
TestVersionedOpA::readProperties(::mlir::DialectBytecodeReader &reader,
::mlir::OperationState &state) {
diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
index 373aeebd0a767a..b3678868e17dec 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -2047,7 +2047,8 @@ def RegionIfYieldOp : TEST_Op<"region_if_yield",
def RegionIfOp : TEST_Op<"region_if",
[DeclareOpInterfaceMethods<RegionBranchOpInterface,
- ["getRegionInvocationBounds"]>,
+ ["getRegionInvocationBounds",
+ "getEntrySuccessorOperands"]>,
SingleBlockImplicitTerminator<"RegionIfYieldOp">,
RecursiveMemoryEffects]> {
let description =[{
@@ -2071,8 +2072,6 @@ def RegionIfOp : TEST_Op<"region_if",
::mlir::Block::BlockArgListType getJoinArgs() {
return getBody(2)->getArguments();
}
- ::mlir::OperandRange getSuccessorEntryOperands(
- ::std::optional<unsigned> index);
}];
let hasCustomAssemblyFormat = 1;
}
@@ -2824,7 +2823,7 @@ def TestStoreWithARegion : TEST_Op<"store_with_a_region",
}
def TestStoreWithARegionTerminator : TEST_Op<"store_with_a_region_terminator",
- [DeclareOpInterfaceMethods<RegionBranchTerminatorOpInterface>, Terminator, NoMemoryEffect]> {
+ [ReturnLike, Terminator, NoMemoryEffect]> {
let assemblyFormat = "attr-dict";
}
diff --git a/mlir/unittests/Interfaces/ControlFlowInterfacesTest.cpp b/mlir/unittests/Interfaces/ControlFlowInterfacesTest.cpp
index 230298cef85697..a507baa6445d97 100644
--- a/mlir/unittests/Interfaces/ControlFlowInterfacesTest.cpp
+++ b/mlir/unittests/Interfaces/ControlFlowInterfacesTest.cpp
@@ -38,7 +38,6 @@ struct MutuallyExclusiveRegionsOp
// Regions have no successors.
void getSuccessorRegions(std::optional<unsigned> index,
- ArrayRef<Attribute> operands,
SmallVectorImpl<RegionSuccessor> ®ions) {}
};
@@ -53,7 +52,6 @@ struct LoopRegionsOp
static StringRef getOperationName() { return "cftest.loop_regions_op"; }
void getSuccessorRegions(std::optional<unsigned> index,
- ArrayRef<Attribute> operands,
SmallVectorImpl<RegionSuccessor> ®ions) {
if (index) {
if (*index == 1)
@@ -77,7 +75,6 @@ struct DoubleLoopRegionsOp
}
void getSuccessorRegions(std::optional<unsigned> index,
- ArrayRef<Attribute> operands,
SmallVectorImpl<RegionSuccessor> ®ions) {
if (index.has_value()) {
regions.push_back(RegionSuccessor());
@@ -96,7 +93,6 @@ struct SequentialRegionsOp
// Region 0 has Region 1 as a successor.
void getSuccessorRegions(std::optional<unsigned> index,
- ArrayRef<Attribute> operands,
SmallVectorImpl<RegionSuccessor> ®ions) {
if (index == 0u) {
Operation *thisOp = this->getOperation();
More information about the flang-commits
mailing list