[Mlir-commits] [mlir] e3c5471 - Revert " [MLIR] Revamp RegionBranchOpInterface " (#165356)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Tue Oct 28 01:06:19 PDT 2025
Author: Mehdi Amini
Date: 2025-10-28T01:06:14-07:00
New Revision: e3c547179f587299378397ac5c7f7eb8f4ca7976
URL: https://github.com/llvm/llvm-project/commit/e3c547179f587299378397ac5c7f7eb8f4ca7976
DIFF: https://github.com/llvm/llvm-project/commit/e3c547179f587299378397ac5c7f7eb8f4ca7976.diff
LOG: Revert " [MLIR] Revamp RegionBranchOpInterface " (#165356)
Reverts llvm/llvm-project#161575
Broke Windows on ARM buildbot build, needs investigation.
Added:
Modified:
flang/lib/Optimizer/Dialect/FIROps.cpp
mlir/include/mlir/Analysis/DataFlow/DenseAnalysis.h
mlir/include/mlir/Analysis/DataFlow/SparseAnalysis.h
mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
mlir/include/mlir/Dialect/Transform/TuneExtension/TuneExtensionOps.td
mlir/include/mlir/IR/Diagnostics.h
mlir/include/mlir/IR/Operation.h
mlir/include/mlir/IR/Region.h
mlir/include/mlir/Interfaces/ControlFlowInterfaces.h
mlir/include/mlir/Interfaces/ControlFlowInterfaces.td
mlir/lib/Analysis/AliasAnalysis/LocalAliasAnalysis.cpp
mlir/lib/Analysis/DataFlow/DeadCodeAnalysis.cpp
mlir/lib/Analysis/DataFlow/DenseAnalysis.cpp
mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp
mlir/lib/Analysis/SliceWalk.cpp
mlir/lib/Dialect/Affine/IR/AffineOps.cpp
mlir/lib/Dialect/Async/IR/Async.cpp
mlir/lib/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation.cpp
mlir/lib/Dialect/EmitC/IR/EmitC.cpp
mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
mlir/lib/Dialect/SCF/IR/SCF.cpp
mlir/lib/Dialect/SCF/Transforms/ForToWhile.cpp
mlir/lib/Dialect/SCF/Transforms/ForallToFor.cpp
mlir/lib/Dialect/Shape/IR/Shape.cpp
mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
mlir/lib/Dialect/Transform/IR/TransformOps.cpp
mlir/lib/Dialect/Transform/TuneExtension/TuneExtensionOps.cpp
mlir/lib/IR/Diagnostics.cpp
mlir/lib/IR/Region.cpp
mlir/lib/Interfaces/ControlFlowInterfaces.cpp
mlir/lib/Transforms/RemoveDeadValues.cpp
mlir/test/Dialect/SCF/invalid.mlir
mlir/test/lib/Analysis/DataFlow/TestDenseBackwardDataFlowAnalysis.cpp
mlir/test/lib/Dialect/Test/TestOpDefs.cpp
mlir/test/lib/Dialect/Test/TestOps.td
mlir/unittests/Interfaces/ControlFlowInterfacesTest.cpp
Removed:
################################################################################
diff --git a/flang/lib/Optimizer/Dialect/FIROps.cpp b/flang/lib/Optimizer/Dialect/FIROps.cpp
index 4f97acaa88b7a..d0164f32d9b6a 100644
--- a/flang/lib/Optimizer/Dialect/FIROps.cpp
+++ b/flang/lib/Optimizer/Dialect/FIROps.cpp
@@ -4484,7 +4484,7 @@ void fir::IfOp::getSuccessorRegions(
llvm::SmallVectorImpl<mlir::RegionSuccessor> &regions) {
// The `then` and the `else` region branch back to the parent operation.
if (!point.isParent()) {
- regions.push_back(mlir::RegionSuccessor(getOperation(), getResults()));
+ regions.push_back(mlir::RegionSuccessor(getResults()));
return;
}
@@ -4494,8 +4494,7 @@ void fir::IfOp::getSuccessorRegions(
// Don't consider the else region if it is empty.
mlir::Region *elseRegion = &this->getElseRegion();
if (elseRegion->empty())
- regions.push_back(
- mlir::RegionSuccessor(getOperation(), getOperation()->getResults()));
+ regions.push_back(mlir::RegionSuccessor());
else
regions.push_back(mlir::RegionSuccessor(elseRegion));
}
@@ -4514,7 +4513,7 @@ void fir::IfOp::getEntrySuccessorRegions(
if (!getElseRegion().empty())
regions.emplace_back(&getElseRegion());
else
- regions.emplace_back(getOperation(), getOperation()->getResults());
+ regions.emplace_back(getResults());
}
}
diff --git a/mlir/include/mlir/Analysis/DataFlow/DenseAnalysis.h b/mlir/include/mlir/Analysis/DataFlow/DenseAnalysis.h
index 3c87c453a4cf0..8bcfe51ad7cd1 100644
--- a/mlir/include/mlir/Analysis/DataFlow/DenseAnalysis.h
+++ b/mlir/include/mlir/Analysis/DataFlow/DenseAnalysis.h
@@ -397,7 +397,7 @@ class AbstractDenseBackwardDataFlowAnalysis : public DataFlowAnalysis {
/// itself.
virtual void visitRegionBranchControlFlowTransfer(
RegionBranchOpInterface branch, RegionBranchPoint regionFrom,
- RegionSuccessor regionTo, const AbstractDenseLattice &after,
+ RegionBranchPoint regionTo, const AbstractDenseLattice &after,
AbstractDenseLattice *before) {
meet(before, after);
}
@@ -526,7 +526,7 @@ class DenseBackwardDataFlowAnalysis
/// and "to" regions.
virtual void visitRegionBranchControlFlowTransfer(
RegionBranchOpInterface branch, RegionBranchPoint regionFrom,
- RegionSuccessor regionTo, const LatticeT &after, LatticeT *before) {
+ RegionBranchPoint regionTo, const LatticeT &after, LatticeT *before) {
AbstractDenseBackwardDataFlowAnalysis::visitRegionBranchControlFlowTransfer(
branch, regionFrom, regionTo, after, before);
}
@@ -571,7 +571,7 @@ class DenseBackwardDataFlowAnalysis
}
void visitRegionBranchControlFlowTransfer(
RegionBranchOpInterface branch, RegionBranchPoint regionForm,
- RegionSuccessor regionTo, const AbstractDenseLattice &after,
+ RegionBranchPoint regionTo, const AbstractDenseLattice &after,
AbstractDenseLattice *before) final {
visitRegionBranchControlFlowTransfer(branch, regionForm, regionTo,
static_cast<const LatticeT &>(after),
diff --git a/mlir/include/mlir/Analysis/DataFlow/SparseAnalysis.h b/mlir/include/mlir/Analysis/DataFlow/SparseAnalysis.h
index 985573476ab78..1a33ecf8b5aa9 100644
--- a/mlir/include/mlir/Analysis/DataFlow/SparseAnalysis.h
+++ b/mlir/include/mlir/Analysis/DataFlow/SparseAnalysis.h
@@ -286,7 +286,7 @@ class AbstractSparseForwardDataFlowAnalysis : public DataFlowAnalysis {
/// and propagating therefrom.
virtual void
visitRegionSuccessors(ProgramPoint *point, RegionBranchOpInterface branch,
- RegionSuccessor successor,
+ RegionBranchPoint successor,
ArrayRef<AbstractSparseLattice *> lattices);
};
diff --git a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
index 48690151caf01..fadd3fc10bfc4 100644
--- a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
+++ b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
@@ -644,13 +644,6 @@ def ForallOp : SCF_Op<"forall", [
/// Returns true if the mapping specified for this forall op is linear.
bool usesLinearMapping();
-
- /// RegionBranchOpInterface
-
- OperandRange getEntrySuccessorOperands(RegionSuccessor successor) {
- return getInits();
- }
-
}];
}
diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
index ed69287410509..62e66b3dabee8 100644
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
@@ -25,7 +25,7 @@ include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.td"
def AlternativesOp : TransformDialectOp<"alternatives",
[DeclareOpInterfaceMethods<RegionBranchOpInterface,
- ["getEntrySuccessorOperands",
+ ["getEntrySuccessorOperands", "getSuccessorRegions",
"getRegionInvocationBounds"]>,
DeclareOpInterfaceMethods<TransformOpInterface>,
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
@@ -624,7 +624,7 @@ def ForeachOp : TransformDialectOp<"foreach",
[DeclareOpInterfaceMethods<TransformOpInterface>,
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
DeclareOpInterfaceMethods<RegionBranchOpInterface, [
- "getEntrySuccessorOperands"]>,
+ "getSuccessorRegions", "getEntrySuccessorOperands"]>,
SingleBlockImplicitTerminator<"::mlir::transform::YieldOp">
]> {
let summary = "Executes the body for each element of the payload";
@@ -1237,7 +1237,7 @@ def SelectOp : TransformDialectOp<"select",
def SequenceOp : TransformDialectOp<"sequence",
[DeclareOpInterfaceMethods<RegionBranchOpInterface,
- ["getEntrySuccessorOperands",
+ ["getEntrySuccessorOperands", "getSuccessorRegions",
"getRegionInvocationBounds"]>,
MatchOpInterface,
DeclareOpInterfaceMethods<TransformOpInterface>,
diff --git a/mlir/include/mlir/Dialect/Transform/TuneExtension/TuneExtensionOps.td b/mlir/include/mlir/Dialect/Transform/TuneExtension/TuneExtensionOps.td
index 4079848fd203a..d095659fc4838 100644
--- a/mlir/include/mlir/Dialect/Transform/TuneExtension/TuneExtensionOps.td
+++ b/mlir/include/mlir/Dialect/Transform/TuneExtension/TuneExtensionOps.td
@@ -63,7 +63,7 @@ def KnobOp : Op<Transform_Dialect, "tune.knob", [
def AlternativesOp : Op<Transform_Dialect, "tune.alternatives", [
DeclareOpInterfaceMethods<RegionBranchOpInterface,
- ["getEntrySuccessorOperands",
+ ["getEntrySuccessorOperands", "getSuccessorRegions",
"getRegionInvocationBounds"]>,
DeclareOpInterfaceMethods<TransformOpInterface>,
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
diff --git a/mlir/include/mlir/IR/Diagnostics.h b/mlir/include/mlir/IR/Diagnostics.h
index a0a99f4953822..7ff718ad7f241 100644
--- a/mlir/include/mlir/IR/Diagnostics.h
+++ b/mlir/include/mlir/IR/Diagnostics.h
@@ -29,7 +29,6 @@ class MLIRContext;
class Operation;
class OperationName;
class OpPrintingFlags;
-class OpWithFlags;
class Type;
class Value;
@@ -200,7 +199,6 @@ class Diagnostic {
/// Stream in an Operation.
Diagnostic &operator<<(Operation &op);
- Diagnostic &operator<<(OpWithFlags op);
Diagnostic &operator<<(Operation *op) { return *this << *op; }
/// Append an operation with the given printing flags.
Diagnostic &appendOp(Operation &op, const OpPrintingFlags &flags);
diff --git a/mlir/include/mlir/IR/Operation.h b/mlir/include/mlir/IR/Operation.h
index b2019574a820d..5569392cf0b41 100644
--- a/mlir/include/mlir/IR/Operation.h
+++ b/mlir/include/mlir/IR/Operation.h
@@ -1114,7 +1114,6 @@ class OpWithFlags {
: op(op), theFlags(flags) {}
OpPrintingFlags &flags() { return theFlags; }
const OpPrintingFlags &flags() const { return theFlags; }
- Operation *getOperation() const { return op; }
private:
Operation *op;
diff --git a/mlir/include/mlir/IR/Region.h b/mlir/include/mlir/IR/Region.h
index 53d461df98710..1fcb316750230 100644
--- a/mlir/include/mlir/IR/Region.h
+++ b/mlir/include/mlir/IR/Region.h
@@ -379,8 +379,6 @@ class RegionRange
friend RangeBaseT;
};
-llvm::raw_ostream &operator<<(llvm::raw_ostream &os, Region &region);
-
} // namespace mlir
#endif // MLIR_IR_REGION_H
diff --git a/mlir/include/mlir/Interfaces/ControlFlowInterfaces.h b/mlir/include/mlir/Interfaces/ControlFlowInterfaces.h
index 47afd252c6d68..d63800c12d132 100644
--- a/mlir/include/mlir/Interfaces/ControlFlowInterfaces.h
+++ b/mlir/include/mlir/Interfaces/ControlFlowInterfaces.h
@@ -15,16 +15,10 @@
#define MLIR_INTERFACES_CONTROLFLOWINTERFACES_H
#include "mlir/IR/OpDefinition.h"
-#include "mlir/IR/Operation.h"
-#include "llvm/ADT/PointerUnion.h"
-#include "llvm/ADT/STLExtras.h"
-#include "llvm/Support/DebugLog.h"
-#include "llvm/Support/raw_ostream.h"
namespace mlir {
class BranchOpInterface;
class RegionBranchOpInterface;
-class RegionBranchTerminatorOpInterface;
/// This class models how operands are forwarded to block arguments in control
/// flow. It consists of a number, denoting how many of the successors block
@@ -192,40 +186,27 @@ class RegionSuccessor {
public:
/// Initialize a successor that branches to another region of the parent
/// operation.
- /// TODO: the default value for the regionInputs is somehow broken.
- /// A region successor should have its input correctly set.
RegionSuccessor(Region *region, Block::BlockArgListType regionInputs = {})
- : successor(region), inputs(regionInputs) {
- assert(region && "Region must not be null");
- }
+ : region(region), inputs(regionInputs) {}
/// Initialize a successor that branches back to/out of the parent operation.
- /// The target must be one of the recursive parent operations.
- RegionSuccessor(Operation *successorOp, Operation::result_range results)
- : successor(successorOp), inputs(ValueRange(results)) {
- assert(successorOp && "Successor op must not be null");
- }
+ RegionSuccessor(Operation::result_range results)
+ : inputs(ValueRange(results)) {}
+ /// Constructor with no arguments.
+ RegionSuccessor() : inputs(ValueRange()) {}
/// Return the given region successor. Returns nullptr if the successor is the
/// parent operation.
- Region *getSuccessor() const { return dyn_cast<Region *>(successor); }
+ Region *getSuccessor() const { return region; }
/// Return true if the successor is the parent operation.
- bool isParent() const { return isa<Operation *>(successor); }
+ bool isParent() const { return region == nullptr; }
/// Return the inputs to the successor that are remapped by the exit values of
/// the current region.
ValueRange getSuccessorInputs() const { return inputs; }
- bool operator==(RegionSuccessor rhs) const {
- return successor == rhs.successor && inputs == rhs.inputs;
- }
-
- friend bool operator!=(RegionSuccessor lhs, RegionSuccessor rhs) {
- return !(lhs == rhs);
- }
-
private:
- llvm::PointerUnion<Region *, Operation *> successor{nullptr};
+ Region *region{nullptr};
ValueRange inputs;
};
@@ -233,67 +214,64 @@ class RegionSuccessor {
/// `RegionBranchOpInterface`.
/// One can branch from one of two kinds of places:
/// * The parent operation (aka the `RegionBranchOpInterface` implementation)
-/// * A RegionBranchTerminatorOpInterface inside a region within the parent
-// operation.
+/// * A region within the parent operation.
class RegionBranchPoint {
public:
/// Returns an instance of `RegionBranchPoint` representing the parent
/// operation.
static constexpr RegionBranchPoint parent() { return RegionBranchPoint(); }
- /// Creates a `RegionBranchPoint` that branches from the given terminator.
- inline RegionBranchPoint(RegionBranchTerminatorOpInterface predecessor);
+ /// Creates a `RegionBranchPoint` that branches from the given region.
+ /// The pointer must not be null.
+ RegionBranchPoint(Region *region) : maybeRegion(region) {
+ assert(region && "Region must not be null");
+ }
+
+ RegionBranchPoint(Region &region) : RegionBranchPoint(&region) {}
/// Explicitly stops users from constructing with `nullptr`.
RegionBranchPoint(std::nullptr_t) = delete;
+ /// Constructs a `RegionBranchPoint` from the the target of a
+ /// `RegionSuccessor` instance.
+ RegionBranchPoint(RegionSuccessor successor) {
+ if (successor.isParent())
+ maybeRegion = nullptr;
+ else
+ maybeRegion = successor.getSuccessor();
+ }
+
+ /// Assigns a region being branched from.
+ RegionBranchPoint &operator=(Region &region) {
+ maybeRegion = &region;
+ return *this;
+ }
+
/// Returns true if branching from the parent op.
- bool isParent() const { return predecessor == nullptr; }
+ bool isParent() const { return maybeRegion == nullptr; }
- /// Returns the terminator if branching from a region.
+ /// Returns the region if branching from a region.
/// A null pointer otherwise.
- Operation *getTerminatorPredecessorOrNull() const { return predecessor; }
+ Region *getRegionOrNull() const { return maybeRegion; }
/// Returns true if the two branch points are equal.
friend bool operator==(RegionBranchPoint lhs, RegionBranchPoint rhs) {
- return lhs.predecessor == rhs.predecessor;
+ return lhs.maybeRegion == rhs.maybeRegion;
}
private:
// Private constructor to encourage the use of `RegionBranchPoint::parent`.
- constexpr RegionBranchPoint() = default;
+ constexpr RegionBranchPoint() : maybeRegion(nullptr) {}
/// Internal encoding. Uses nullptr for representing branching from the parent
- /// op and the region terminator being branched from otherwise.
- Operation *predecessor = nullptr;
+ /// op and the region being branched from otherwise.
+ Region *maybeRegion;
};
inline bool operator!=(RegionBranchPoint lhs, RegionBranchPoint rhs) {
return !(lhs == rhs);
}
-inline llvm::raw_ostream &operator<<(llvm::raw_ostream &os,
- RegionBranchPoint point) {
- if (point.isParent())
- return os << "<from parent>";
- return os << "<region #"
- << point.getTerminatorPredecessorOrNull()
- ->getParentRegion()
- ->getRegionNumber()
- << ", terminator "
- << OpWithFlags(point.getTerminatorPredecessorOrNull(),
- OpPrintingFlags().skipRegions())
- << ">";
-}
-
-inline llvm::raw_ostream &operator<<(llvm::raw_ostream &os,
- RegionSuccessor successor) {
- if (successor.isParent())
- return os << "<to parent>";
- return os << "<to region #" << successor.getSuccessor()->getRegionNumber()
- << " with " << successor.getSuccessorInputs().size() << " inputs>";
-}
-
/// This class represents upper and lower bounds on the number of times a region
/// of a `RegionBranchOpInterface` can be invoked. The lower bound is at least
/// zero, but the upper bound may not be known.
@@ -370,10 +348,4 @@ struct ReturnLike : public TraitBase<ConcreteType, ReturnLike> {
/// Include the generated interface declarations.
#include "mlir/Interfaces/ControlFlowInterfaces.h.inc"
-namespace mlir {
-inline RegionBranchPoint::RegionBranchPoint(
- RegionBranchTerminatorOpInterface predecessor)
- : predecessor(predecessor.getOperation()) {}
-} // namespace mlir
-
#endif // MLIR_INTERFACES_CONTROLFLOWINTERFACES_H
diff --git a/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td b/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td
index 94242e3ba39ce..b8d08cc553caa 100644
--- a/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td
+++ b/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td
@@ -117,7 +117,7 @@ def BranchOpInterface : OpInterface<"BranchOpInterface"> {
def RegionBranchOpInterface : OpInterface<"RegionBranchOpInterface"> {
let description = [{
- This interface provides information for region-holding operations that exhibit
+ This interface provides information for region operations that exhibit
branching behavior between held regions. I.e., this interface allows for
expressing control flow information for region holding operations.
@@ -126,12 +126,12 @@ def RegionBranchOpInterface : OpInterface<"RegionBranchOpInterface"> {
be side-effect free.
A "region branch point" indicates a point from which a branch originates. It
- can indicate either a terminator in any of the immediately nested region of
- this op or `RegionBranchPoint::parent()`. In the latter case, the branch
- originates from outside of the op, i.e., when first executing this op.
+ can indicate either a region of this op or `RegionBranchPoint::parent()`. In
+ the latter case, the branch originates from outside of the op, i.e., when
+ first executing this op.
A "region successor" indicates the target of a branch. It can indicate
- either a region of this op or this op itself. In the former case, the region
+ either a region of this op or this op. In the former case, the region
successor is a region pointer and a range of block arguments to which the
"successor operands" are forwarded to. In the latter case, the control flow
leaves this op and the region successor is a range of results of this op to
@@ -151,10 +151,10 @@ def RegionBranchOpInterface : OpInterface<"RegionBranchOpInterface"> {
}
```
- `scf.for` has one region. The `scf.yield` has two region successors: the
- region body itself and the `scf.for` op. `%b` is an entry successor
- operand. `%c` is a successor operand. `%a` is a successor block argument.
- `%r` is a successor result.
+ `scf.for` has one region. The region has two region successors: the region
+ itself and the `scf.for` op. %b is an entry successor operand. %c is a
+ successor operand. %a is a successor block argument. %r is a successor
+ result.
}];
let cppNamespace = "::mlir";
@@ -162,16 +162,16 @@ def RegionBranchOpInterface : OpInterface<"RegionBranchOpInterface"> {
InterfaceMethod<[{
Returns the operands of this operation that are forwarded to the region
successor's block arguments or this operation's results when branching
- to `successor`. `successor` is guaranteed to be among the successors that are
+ to `point`. `point` is guaranteed to be among the successors that are
returned by `getEntrySuccessorRegions`/`getSuccessorRegions(parent())`.
Example: In the above example, this method returns the operand %b of the
- `scf.for` op, regardless of the value of `successor`. I.e., this op always
+ `scf.for` op, regardless of the value of `point`. I.e., this op always
forwards the same operands, regardless of whether the loop has 0 or more
iterations.
}],
"::mlir::OperandRange", "getEntrySuccessorOperands",
- (ins "::mlir::RegionSuccessor":$successor), [{}],
+ (ins "::mlir::RegionBranchPoint":$point), [{}],
/*defaultImplementation=*/[{
auto operandEnd = this->getOperation()->operand_end();
return ::mlir::OperandRange(operandEnd, operandEnd);
@@ -224,80 +224,6 @@ def RegionBranchOpInterface : OpInterface<"RegionBranchOpInterface"> {
(ins "::mlir::RegionBranchPoint":$point,
"::llvm::SmallVectorImpl<::mlir::RegionSuccessor> &":$regions)
>,
- InterfaceMethod<[{
- Returns the potential region successors when branching from any
- terminator in `region`.
- These are the regions that may be selected during the flow of control.
- }],
- "void", "getSuccessorRegions",
- (ins "::mlir::Region&":$region,
- "::llvm::SmallVectorImpl<::mlir::RegionSuccessor> &":$regions),
- [{}],
- /*defaultImplementation=*/[{
- for (::mlir::Block &block : region) {
- if (block.empty())
- continue;
- if (auto terminator =
- dyn_cast<RegionBranchTerminatorOpInterface>(block.back()))
- $_op.getSuccessorRegions(RegionBranchPoint(terminator),
- regions);
- }
- }]>,
- InterfaceMethod<[{
- Returns the potential branching point (predecessors) for a given successor.
- }],
- "void", "getPredecessors",
- (ins "::mlir::RegionSuccessor":$successor,
- "::llvm::SmallVectorImpl<::mlir::RegionBranchPoint> &":$predecessors),
- [{}],
- /*defaultImplementation=*/[{
- ::llvm::SmallVector<::mlir::RegionSuccessor> successors;
- $_op.getSuccessorRegions(RegionBranchPoint::parent(),
- successors);
- if (llvm::any_of(successors, [&] (const RegionSuccessor & succ) {
- return succ.getSuccessor() == successor.getSuccessor() ||
- (succ.isParent() && successor.isParent());
- }))
- predecessors.push_back(RegionBranchPoint::parent());
- for (Region &region : $_op->getRegions()) {
- for (::mlir::Block &block : region) {
- if (block.empty())
- continue;
- if (auto terminator =
- dyn_cast<RegionBranchTerminatorOpInterface>(block.back())) {
- ::llvm::SmallVector<::mlir::RegionSuccessor> successors;
- $_op.getSuccessorRegions(RegionBranchPoint(terminator),
- successors);
- if (llvm::any_of(successors, [&] (const RegionSuccessor & succ) {
- return succ.getSuccessor() == successor.getSuccessor() ||
- (succ.isParent() && successor.isParent());
- }))
- predecessors.push_back(terminator);
- }
- }
- }
- }]>,
- InterfaceMethod<[{
- Returns the potential values across all (predecessors) for a given successor
- input, modeled by its index (its position in the list of values).
- }],
- "void", "getPredecessorValues",
- (ins "::mlir::RegionSuccessor":$successor,
- "int":$index,
- "::llvm::SmallVectorImpl<::mlir::Value> &":$predecessorValues),
- [{}],
- /*defaultImplementation=*/[{
- ::llvm::SmallVector<::mlir::RegionBranchPoint> predecessors;
- $_op.getPredecessors(successor, predecessors);
- for (auto predecessor : predecessors) {
- if (predecessor.isParent()) {
- predecessorValues.push_back($_op.getEntrySuccessorOperands(successor)[index]);
- continue;
- }
- auto terminator = cast<RegionBranchTerminatorOpInterface>(predecessor.getTerminatorPredecessorOrNull());
- predecessorValues.push_back(terminator.getSuccessorOperands(successor)[index]);
- }
- }]>,
InterfaceMethod<[{
Populates `invocationBounds` with the minimum and maximum number of
times this operation will invoke the attached regions (assuming the
@@ -372,7 +298,7 @@ def RegionBranchTerminatorOpInterface :
passing them to the region successor indicated by `point`.
}],
"::mlir::MutableOperandRange", "getMutableSuccessorOperands",
- (ins "::mlir::RegionSuccessor":$point)
+ (ins "::mlir::RegionBranchPoint":$point)
>,
InterfaceMethod<[{
Returns the potential region successors that are branched to after this
@@ -391,7 +317,7 @@ def RegionBranchTerminatorOpInterface :
/*defaultImplementation=*/[{
::mlir::Operation *op = $_op;
::llvm::cast<::mlir::RegionBranchOpInterface>(op->getParentOp())
- .getSuccessorRegions(::llvm::cast<::mlir::RegionBranchTerminatorOpInterface>(op), regions);
+ .getSuccessorRegions(op->getParentRegion(), regions);
}]
>,
];
@@ -411,8 +337,8 @@ def RegionBranchTerminatorOpInterface :
// them to the region successor given by `index`. If `index` is None, this
// function returns the operands that are passed as a result to the parent
// operation.
- ::mlir::OperandRange getSuccessorOperands(::mlir::RegionSuccessor successor) {
- return getMutableSuccessorOperands(successor);
+ ::mlir::OperandRange getSuccessorOperands(::mlir::RegionBranchPoint point) {
+ return getMutableSuccessorOperands(point);
}
}];
}
@@ -578,7 +504,7 @@ def ReturnLike : TraitList<[
/*extraOpDeclaration=*/"",
/*extraOpDefinition=*/[{
::mlir::MutableOperandRange $cppClass::getMutableSuccessorOperands(
- ::mlir::RegionSuccessor successor) {
+ ::mlir::RegionBranchPoint point) {
return ::mlir::MutableOperandRange(*this);
}
}]
diff --git a/mlir/lib/Analysis/AliasAnalysis/LocalAliasAnalysis.cpp b/mlir/lib/Analysis/AliasAnalysis/LocalAliasAnalysis.cpp
index 24cb123e51877..a84d10d5d609d 100644
--- a/mlir/lib/Analysis/AliasAnalysis/LocalAliasAnalysis.cpp
+++ b/mlir/lib/Analysis/AliasAnalysis/LocalAliasAnalysis.cpp
@@ -16,21 +16,19 @@
#include "mlir/IR/Operation.h"
#include "mlir/IR/Region.h"
#include "mlir/IR/Value.h"
+#include "mlir/IR/ValueRange.h"
#include "mlir/Interfaces/ControlFlowInterfaces.h"
#include "mlir/Interfaces/FunctionInterfaces.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "mlir/Interfaces/ViewLikeInterface.h"
#include "mlir/Support/LLVM.h"
#include "llvm/Support/Casting.h"
-#include "llvm/Support/DebugLog.h"
#include <cassert>
#include <optional>
#include <utility>
using namespace mlir;
-#define DEBUG_TYPE "local-alias-analysis"
-
//===----------------------------------------------------------------------===//
// Underlying Address Computation
//===----------------------------------------------------------------------===//
@@ -44,47 +42,81 @@ static void collectUnderlyingAddressValues(Value value, unsigned maxDepth,
DenseSet<Value> &visited,
SmallVectorImpl<Value> &output);
-/// Given a RegionBranchOpInterface operation (`branch`), a Value`inputValue`
-/// which is an input for the provided successor (`initialSuccessor`), try to
-/// find the possible sources for the value along the control flow edges.
-static void collectUnderlyingAddressValues2(
- RegionBranchOpInterface branch, RegionSuccessor initialSuccessor,
- Value inputValue, unsigned inputIndex, unsigned maxDepth,
- DenseSet<Value> &visited, SmallVectorImpl<Value> &output) {
- LDBG() << "collectUnderlyingAddressValues2: "
- << OpWithFlags(branch.getOperation(), OpPrintingFlags().skipRegions());
- LDBG() << " with initialSuccessor " << initialSuccessor;
- LDBG() << " inputValue: " << inputValue;
- LDBG() << " inputIndex: " << inputIndex;
- LDBG() << " maxDepth: " << maxDepth;
- ValueRange inputs = initialSuccessor.getSuccessorInputs();
- if (inputs.empty()) {
- LDBG() << " input is empty, enqueue value";
- output.push_back(inputValue);
- return;
- }
- unsigned firstInputIndex, lastInputIndex;
- if (isa<BlockArgument>(inputs[0])) {
- firstInputIndex = cast<BlockArgument>(inputs[0]).getArgNumber();
- lastInputIndex = cast<BlockArgument>(inputs.back()).getArgNumber();
- } else {
- firstInputIndex = cast<OpResult>(inputs[0]).getResultNumber();
- lastInputIndex = cast<OpResult>(inputs.back()).getResultNumber();
- }
- if (firstInputIndex > inputIndex || lastInputIndex < inputIndex) {
- LDBG() << " !! Input index " << inputIndex << " out of range "
- << firstInputIndex << " to " << lastInputIndex
- << ", adding input value to output";
- output.push_back(inputValue);
- return;
+/// Given a successor (`region`) of a RegionBranchOpInterface, collect all of
+/// the underlying values being addressed by one of the successor inputs. If the
+/// provided `region` is null, as per `RegionBranchOpInterface` this represents
+/// the parent operation.
+static void collectUnderlyingAddressValues(RegionBranchOpInterface branch,
+ Region *region, Value inputValue,
+ unsigned inputIndex,
+ unsigned maxDepth,
+ DenseSet<Value> &visited,
+ SmallVectorImpl<Value> &output) {
+ // Given the index of a region of the branch (`predIndex`), or std::nullopt to
+ // represent the parent operation, try to return the index into the outputs of
+ // this region predecessor that correspond to the input values of `region`. If
+ // an index could not be found, std::nullopt is returned instead.
+ auto getOperandIndexIfPred =
+ [&](RegionBranchPoint pred) -> std::optional<unsigned> {
+ SmallVector<RegionSuccessor, 2> successors;
+ branch.getSuccessorRegions(pred, successors);
+ for (RegionSuccessor &successor : successors) {
+ if (successor.getSuccessor() != region)
+ continue;
+ // Check that the successor inputs map to the given input value.
+ ValueRange inputs = successor.getSuccessorInputs();
+ if (inputs.empty()) {
+ output.push_back(inputValue);
+ break;
+ }
+ unsigned firstInputIndex, lastInputIndex;
+ if (region) {
+ firstInputIndex = cast<BlockArgument>(inputs[0]).getArgNumber();
+ lastInputIndex = cast<BlockArgument>(inputs.back()).getArgNumber();
+ } else {
+ firstInputIndex = cast<OpResult>(inputs[0]).getResultNumber();
+ lastInputIndex = cast<OpResult>(inputs.back()).getResultNumber();
+ }
+ if (firstInputIndex > inputIndex || lastInputIndex < inputIndex) {
+ output.push_back(inputValue);
+ break;
+ }
+ return inputIndex - firstInputIndex;
+ }
+ return std::nullopt;
+ };
+
+ // Check branches from the parent operation.
+ auto branchPoint = RegionBranchPoint::parent();
+ if (region)
+ branchPoint = region;
+
+ if (std::optional<unsigned> operandIndex =
+ getOperandIndexIfPred(/*predIndex=*/RegionBranchPoint::parent())) {
+ collectUnderlyingAddressValues(
+ branch.getEntrySuccessorOperands(branchPoint)[*operandIndex], maxDepth,
+ visited, output);
}
- SmallVector<Value> predecessorValues;
- branch.getPredecessorValues(initialSuccessor, inputIndex - firstInputIndex,
- predecessorValues);
- LDBG() << " Found " << predecessorValues.size() << " predecessor values";
- for (Value predecessorValue : predecessorValues) {
- LDBG() << " Processing predecessor value: " << predecessorValue;
- collectUnderlyingAddressValues(predecessorValue, maxDepth, visited, output);
+ // Check branches from each child region.
+ Operation *op = branch.getOperation();
+ for (Region &region : op->getRegions()) {
+ if (std::optional<unsigned> operandIndex = getOperandIndexIfPred(region)) {
+ for (Block &block : region) {
+ // Try to determine possible region-branch successor operands for the
+ // current region.
+ if (auto term = dyn_cast<RegionBranchTerminatorOpInterface>(
+ block.getTerminator())) {
+ collectUnderlyingAddressValues(
+ term.getSuccessorOperands(branchPoint)[*operandIndex], maxDepth,
+ visited, output);
+ } else if (block.getNumSuccessors()) {
+ // Otherwise, if this terminator may exit the region we can't make
+ // any assumptions about which values get passed.
+ output.push_back(inputValue);
+ return;
+ }
+ }
+ }
}
}
@@ -92,28 +124,22 @@ static void collectUnderlyingAddressValues2(
static void collectUnderlyingAddressValues(OpResult result, unsigned maxDepth,
DenseSet<Value> &visited,
SmallVectorImpl<Value> &output) {
- LDBG() << "collectUnderlyingAddressValues (OpResult): " << result;
- LDBG() << " maxDepth: " << maxDepth;
-
Operation *op = result.getOwner();
// If this is a view, unwrap to the source.
if (ViewLikeOpInterface view = dyn_cast<ViewLikeOpInterface>(op)) {
if (result == view.getViewDest()) {
- LDBG() << " Unwrapping view to source: " << view.getViewSource();
return collectUnderlyingAddressValues(view.getViewSource(), maxDepth,
visited, output);
}
}
// Check to see if we can reason about the control flow of this op.
if (auto branch = dyn_cast<RegionBranchOpInterface>(op)) {
- LDBG() << " Processing region branch operation";
- return collectUnderlyingAddressValues2(
- branch, RegionSuccessor(op, op->getResults()), result,
- result.getResultNumber(), maxDepth, visited, output);
+ return collectUnderlyingAddressValues(branch, /*region=*/nullptr, result,
+ result.getResultNumber(), maxDepth,
+ visited, output);
}
- LDBG() << " Adding result to output: " << result;
output.push_back(result);
}
@@ -122,23 +148,14 @@ static void collectUnderlyingAddressValues(OpResult result, unsigned maxDepth,
static void collectUnderlyingAddressValues(BlockArgument arg, unsigned maxDepth,
DenseSet<Value> &visited,
SmallVectorImpl<Value> &output) {
- LDBG() << "collectUnderlyingAddressValues (BlockArgument): " << arg;
- LDBG() << " maxDepth: " << maxDepth;
- LDBG() << " argNumber: " << arg.getArgNumber();
- LDBG() << " isEntryBlock: " << arg.getOwner()->isEntryBlock();
-
Block *block = arg.getOwner();
unsigned argNumber = arg.getArgNumber();
// Handle the case of a non-entry block.
if (!block->isEntryBlock()) {
- LDBG() << " Processing non-entry block with "
- << std::distance(block->pred_begin(), block->pred_end())
- << " predecessors";
for (auto it = block->pred_begin(), e = block->pred_end(); it != e; ++it) {
auto branch = dyn_cast<BranchOpInterface>((*it)->getTerminator());
if (!branch) {
- LDBG() << " Cannot analyze control flow, adding argument to output";
// We can't analyze the control flow, so bail out early.
output.push_back(arg);
return;
@@ -148,12 +165,10 @@ static void collectUnderlyingAddressValues(BlockArgument arg, unsigned maxDepth,
unsigned index = it.getSuccessorIndex();
Value operand = branch.getSuccessorOperands(index)[argNumber];
if (!operand) {
- LDBG() << " No operand found for argument, adding to output";
// We can't analyze the control flow, so bail out early.
output.push_back(arg);
return;
}
- LDBG() << " Processing operand from predecessor: " << operand;
collectUnderlyingAddressValues(operand, maxDepth, visited, output);
}
return;
@@ -163,35 +178,10 @@ static void collectUnderlyingAddressValues(BlockArgument arg, unsigned maxDepth,
Region *region = block->getParent();
Operation *op = region->getParentOp();
if (auto branch = dyn_cast<RegionBranchOpInterface>(op)) {
- LDBG() << " Processing region branch operation for entry block";
- // We have to find the successor matching the region, so that the input
- // arguments are correctly set.
- // TODO: this isn't comprehensive: the successor may not be reachable from
- // the entry block.
- SmallVector<RegionSuccessor> successors;
- branch.getSuccessorRegions(RegionBranchPoint::parent(), successors);
- RegionSuccessor regionSuccessor(region);
- bool found = false;
- for (RegionSuccessor &successor : successors) {
- if (successor.getSuccessor() == region) {
- LDBG() << " Found matching region successor: " << successor;
- found = true;
- regionSuccessor = successor;
- break;
- }
- }
- if (!found) {
- LDBG()
- << " No matching region successor found, adding argument to output";
- output.push_back(arg);
- return;
- }
- return collectUnderlyingAddressValues2(
- branch, regionSuccessor, arg, argNumber, maxDepth, visited, output);
+ return collectUnderlyingAddressValues(branch, region, arg, argNumber,
+ maxDepth, visited, output);
}
- LDBG()
- << " Cannot reason about underlying address, adding argument to output";
// We can't reason about the underlying address of this argument.
output.push_back(arg);
}
@@ -200,26 +190,17 @@ static void collectUnderlyingAddressValues(BlockArgument arg, unsigned maxDepth,
static void collectUnderlyingAddressValues(Value value, unsigned maxDepth,
DenseSet<Value> &visited,
SmallVectorImpl<Value> &output) {
- LDBG() << "collectUnderlyingAddressValues: " << value;
- LDBG() << " maxDepth: " << maxDepth;
-
// Check that we don't infinitely recurse.
- if (!visited.insert(value).second) {
- LDBG() << " Value already visited, skipping";
+ if (!visited.insert(value).second)
return;
- }
if (maxDepth == 0) {
- LDBG() << " Max depth reached, adding value to output";
output.push_back(value);
return;
}
--maxDepth;
- if (BlockArgument arg = dyn_cast<BlockArgument>(value)) {
- LDBG() << " Processing as BlockArgument";
+ if (BlockArgument arg = dyn_cast<BlockArgument>(value))
return collectUnderlyingAddressValues(arg, maxDepth, visited, output);
- }
- LDBG() << " Processing as OpResult";
collectUnderlyingAddressValues(cast<OpResult>(value), maxDepth, visited,
output);
}
@@ -227,11 +208,9 @@ static void collectUnderlyingAddressValues(Value value, unsigned maxDepth,
/// Given a value, collect all of the underlying values being addressed.
static void collectUnderlyingAddressValues(Value value,
SmallVectorImpl<Value> &output) {
- LDBG() << "collectUnderlyingAddressValues: " << value;
DenseSet<Value> visited;
collectUnderlyingAddressValues(value, maxUnderlyingValueSearchDepth, visited,
output);
- LDBG() << " Collected " << output.size() << " underlying values";
}
//===----------------------------------------------------------------------===//
@@ -248,33 +227,19 @@ static LogicalResult
getAllocEffectFor(Value value,
std::optional<MemoryEffects::EffectInstance> &effect,
Operation *&allocScopeOp) {
- LDBG() << "getAllocEffectFor: " << value;
-
// Try to get a memory effect interface for the parent operation.
Operation *op;
- if (BlockArgument arg = dyn_cast<BlockArgument>(value)) {
+ if (BlockArgument arg = dyn_cast<BlockArgument>(value))
op = arg.getOwner()->getParentOp();
- LDBG() << " BlockArgument, parent op: "
- << OpWithFlags(op, OpPrintingFlags().skipRegions());
- } else {
+ else
op = cast<OpResult>(value).getOwner();
- LDBG() << " OpResult, owner op: "
- << OpWithFlags(op, OpPrintingFlags().skipRegions());
- }
-
MemoryEffectOpInterface interface = dyn_cast<MemoryEffectOpInterface>(op);
- if (!interface) {
- LDBG() << " No memory effect interface found";
+ if (!interface)
return failure();
- }
// Try to find an allocation effect on the resource.
- if (!(effect = interface.getEffectOnValue<MemoryEffects::Allocate>(value))) {
- LDBG() << " No allocation effect found on value";
+ if (!(effect = interface.getEffectOnValue<MemoryEffects::Allocate>(value)))
return failure();
- }
-
- LDBG() << " Found allocation effect";
// If we found an allocation effect, try to find a scope for the allocation.
// If the resource of this allocation is automatically scoped, find the parent
@@ -282,12 +247,6 @@ getAllocEffectFor(Value value,
if (llvm::isa<SideEffects::AutomaticAllocationScopeResource>(
effect->getResource())) {
allocScopeOp = op->getParentWithTrait<OpTrait::AutomaticAllocationScope>();
- if (allocScopeOp) {
- LDBG() << " Automatic allocation scope found: "
- << OpWithFlags(allocScopeOp, OpPrintingFlags().skipRegions());
- } else {
- LDBG() << " Automatic allocation scope found: null";
- }
return success();
}
@@ -296,12 +255,6 @@ getAllocEffectFor(Value value,
// For now assume allocation scope to the function scope (we don't care if
// pointer escape outside function).
allocScopeOp = op->getParentOfType<FunctionOpInterface>();
- if (allocScopeOp) {
- LDBG() << " Function scope found: "
- << OpWithFlags(allocScopeOp, OpPrintingFlags().skipRegions());
- } else {
- LDBG() << " Function scope found: null";
- }
return success();
}
@@ -340,44 +293,33 @@ static std::optional<AliasResult> checkDistinctObjects(Value lhs, Value rhs) {
/// Given the two values, return their aliasing behavior.
AliasResult LocalAliasAnalysis::aliasImpl(Value lhs, Value rhs) {
- LDBG() << "aliasImpl: " << lhs << " vs " << rhs;
-
- if (lhs == rhs) {
- LDBG() << " Same value, must alias";
+ if (lhs == rhs)
return AliasResult::MustAlias;
- }
-
Operation *lhsAllocScope = nullptr, *rhsAllocScope = nullptr;
std::optional<MemoryEffects::EffectInstance> lhsAlloc, rhsAlloc;
// Handle the case where lhs is a constant.
Attribute lhsAttr, rhsAttr;
if (matchPattern(lhs, m_Constant(&lhsAttr))) {
- LDBG() << " lhs is constant";
// TODO: This is overly conservative. Two matching constants don't
// necessarily map to the same address. For example, if the two values
// correspond to different symbols that both represent a definition.
- if (matchPattern(rhs, m_Constant(&rhsAttr))) {
- LDBG() << " rhs is also constant, may alias";
+ if (matchPattern(rhs, m_Constant(&rhsAttr)))
return AliasResult::MayAlias;
- }
// Try to find an alloc effect on rhs. If an effect was found we can't
// alias, otherwise we might.
- bool rhsHasAlloc =
- succeeded(getAllocEffectFor(rhs, rhsAlloc, rhsAllocScope));
- LDBG() << " rhs has alloc effect: " << rhsHasAlloc;
- return rhsHasAlloc ? AliasResult::NoAlias : AliasResult::MayAlias;
+ return succeeded(getAllocEffectFor(rhs, rhsAlloc, rhsAllocScope))
+ ? AliasResult::NoAlias
+ : AliasResult::MayAlias;
}
// Handle the case where rhs is a constant.
if (matchPattern(rhs, m_Constant(&rhsAttr))) {
- LDBG() << " rhs is constant";
// Try to find an alloc effect on lhs. If an effect was found we can't
// alias, otherwise we might.
- bool lhsHasAlloc =
- succeeded(getAllocEffectFor(lhs, lhsAlloc, lhsAllocScope));
- LDBG() << " lhs has alloc effect: " << lhsHasAlloc;
- return lhsHasAlloc ? AliasResult::NoAlias : AliasResult::MayAlias;
+ return succeeded(getAllocEffectFor(lhs, lhsAlloc, lhsAllocScope))
+ ? AliasResult::NoAlias
+ : AliasResult::MayAlias;
}
if (std::optional<AliasResult> result = checkDistinctObjects(lhs, rhs))
@@ -387,14 +329,9 @@ AliasResult LocalAliasAnalysis::aliasImpl(Value lhs, Value rhs) {
// an allocation effect.
bool lhsHasAlloc = succeeded(getAllocEffectFor(lhs, lhsAlloc, lhsAllocScope));
bool rhsHasAlloc = succeeded(getAllocEffectFor(rhs, rhsAlloc, rhsAllocScope));
- LDBG() << " lhs has alloc effect: " << lhsHasAlloc;
- LDBG() << " rhs has alloc effect: " << rhsHasAlloc;
-
if (lhsHasAlloc == rhsHasAlloc) {
// If both values have an allocation effect we know they don't alias, and if
// neither have an effect we can't make an assumptions.
- LDBG() << " Both have same alloc status: "
- << (lhsHasAlloc ? "NoAlias" : "MayAlias");
return lhsHasAlloc ? AliasResult::NoAlias : AliasResult::MayAlias;
}
@@ -402,7 +339,6 @@ AliasResult LocalAliasAnalysis::aliasImpl(Value lhs, Value rhs) {
// and one without. Move the one with the effect to the lhs to make the next
// checks simpler.
if (rhsHasAlloc) {
- LDBG() << " Swapping lhs and rhs to put alloc effect on lhs";
std::swap(lhs, rhs);
lhsAlloc = rhsAlloc;
lhsAllocScope = rhsAllocScope;
@@ -411,74 +347,49 @@ AliasResult LocalAliasAnalysis::aliasImpl(Value lhs, Value rhs) {
// If the effect has a scoped allocation region, check to see if the
// non-effect value is defined above that scope.
if (lhsAllocScope) {
- LDBG() << " Checking allocation scope: "
- << OpWithFlags(lhsAllocScope, OpPrintingFlags().skipRegions());
// If the parent operation of rhs is an ancestor of the allocation scope, or
// if rhs is an entry block argument of the allocation scope we know the two
// values can't alias.
Operation *rhsParentOp = rhs.getParentRegion()->getParentOp();
- if (rhsParentOp->isProperAncestor(lhsAllocScope)) {
- LDBG() << " rhs parent is ancestor of alloc scope, no alias";
+ if (rhsParentOp->isProperAncestor(lhsAllocScope))
return AliasResult::NoAlias;
- }
if (rhsParentOp == lhsAllocScope) {
BlockArgument rhsArg = dyn_cast<BlockArgument>(rhs);
- if (rhsArg && rhs.getParentBlock()->isEntryBlock()) {
- LDBG() << " rhs is entry block arg of alloc scope, no alias";
+ if (rhsArg && rhs.getParentBlock()->isEntryBlock())
return AliasResult::NoAlias;
- }
}
}
// If we couldn't reason about the relationship between the two values,
// conservatively assume they might alias.
- LDBG() << " Cannot reason about relationship, may alias";
return AliasResult::MayAlias;
}
/// Given the two values, return their aliasing behavior.
AliasResult LocalAliasAnalysis::alias(Value lhs, Value rhs) {
- LDBG() << "alias: " << lhs << " vs " << rhs;
-
- if (lhs == rhs) {
- LDBG() << " Same value, must alias";
+ if (lhs == rhs)
return AliasResult::MustAlias;
- }
// Get the underlying values being addressed.
SmallVector<Value, 8> lhsValues, rhsValues;
collectUnderlyingAddressValues(lhs, lhsValues);
collectUnderlyingAddressValues(rhs, rhsValues);
- LDBG() << " lhs underlying values: " << lhsValues.size();
- LDBG() << " rhs underlying values: " << rhsValues.size();
-
// If we failed to collect for either of the values somehow, conservatively
// assume they may alias.
- if (lhsValues.empty() || rhsValues.empty()) {
- LDBG() << " Failed to collect underlying values, may alias";
+ if (lhsValues.empty() || rhsValues.empty())
return AliasResult::MayAlias;
- }
// Check the alias results against each of the underlying values.
std::optional<AliasResult> result;
for (Value lhsVal : lhsValues) {
for (Value rhsVal : rhsValues) {
- LDBG() << " Checking underlying values: " << lhsVal << " vs " << rhsVal;
AliasResult nextResult = aliasImpl(lhsVal, rhsVal);
- LDBG() << " Result: "
- << (nextResult == AliasResult::MustAlias ? "MustAlias"
- : nextResult == AliasResult::NoAlias ? "NoAlias"
- : "MayAlias");
result = result ? result->merge(nextResult) : nextResult;
}
}
// We should always have a valid result here.
- LDBG() << " Final result: "
- << (result->isMust() ? "MustAlias"
- : result->isNo() ? "NoAlias"
- : "MayAlias");
return *result;
}
@@ -487,12 +398,8 @@ AliasResult LocalAliasAnalysis::alias(Value lhs, Value rhs) {
//===----------------------------------------------------------------------===//
ModRefResult LocalAliasAnalysis::getModRef(Operation *op, Value location) {
- LDBG() << "getModRef: " << OpWithFlags(op, OpPrintingFlags().skipRegions())
- << " on location " << location;
-
// Check to see if this operation relies on nested side effects.
if (op->hasTrait<OpTrait::HasRecursiveMemoryEffects>()) {
- LDBG() << " Operation has recursive memory effects, returning ModAndRef";
// TODO: To check recursive operations we need to check all of the nested
// operations, which can result in a quadratic number of queries. We should
// introduce some caching of some kind to help alleviate this, especially as
@@ -503,64 +410,38 @@ ModRefResult LocalAliasAnalysis::getModRef(Operation *op, Value location) {
// Otherwise, check to see if this operation has a memory effect interface.
MemoryEffectOpInterface interface = dyn_cast<MemoryEffectOpInterface>(op);
- if (!interface) {
- LDBG() << " No memory effect interface, returning ModAndRef";
+ if (!interface)
return ModRefResult::getModAndRef();
- }
// Build a ModRefResult by merging the behavior of the effects of this
// operation.
SmallVector<MemoryEffects::EffectInstance> effects;
interface.getEffects(effects);
- LDBG() << " Found " << effects.size() << " memory effects";
ModRefResult result = ModRefResult::getNoModRef();
for (const MemoryEffects::EffectInstance &effect : effects) {
- if (isa<MemoryEffects::Allocate, MemoryEffects::Free>(effect.getEffect())) {
- LDBG() << " Skipping alloc/free effect";
+ if (isa<MemoryEffects::Allocate, MemoryEffects::Free>(effect.getEffect()))
continue;
- }
// Check for an alias between the effect and our memory location.
// TODO: Add support for checking an alias with a symbol reference.
AliasResult aliasResult = AliasResult::MayAlias;
- if (Value effectValue = effect.getValue()) {
- LDBG() << " Checking alias between effect value " << effectValue
- << " and location " << location;
+ if (Value effectValue = effect.getValue())
aliasResult = alias(effectValue, location);
- LDBG() << " Alias result: "
- << (aliasResult.isMust() ? "MustAlias"
- : aliasResult.isNo() ? "NoAlias"
- : "MayAlias");
- } else {
- LDBG() << " No effect value, assuming MayAlias";
- }
// If we don't alias, ignore this effect.
- if (aliasResult.isNo()) {
- LDBG() << " No alias, ignoring effect";
+ if (aliasResult.isNo())
continue;
- }
// Merge in the corresponding mod or ref for this effect.
if (isa<MemoryEffects::Read>(effect.getEffect())) {
- LDBG() << " Adding Ref to result";
result = result.merge(ModRefResult::getRef());
} else {
assert(isa<MemoryEffects::Write>(effect.getEffect()));
- LDBG() << " Adding Mod to result";
result = result.merge(ModRefResult::getMod());
}
- if (result.isModAndRef()) {
- LDBG() << " Result is now ModAndRef, breaking";
+ if (result.isModAndRef())
break;
- }
}
-
- LDBG() << " Final ModRef result: "
- << (result.isModAndRef() ? "ModAndRef"
- : result.isMod() ? "Mod"
- : result.isRef() ? "Ref"
- : "NoModRef");
return result;
}
diff --git a/mlir/lib/Analysis/DataFlow/DeadCodeAnalysis.cpp b/mlir/lib/Analysis/DataFlow/DeadCodeAnalysis.cpp
index 0fc5b4482bf3e..377f7ebe06750 100644
--- a/mlir/lib/Analysis/DataFlow/DeadCodeAnalysis.cpp
+++ b/mlir/lib/Analysis/DataFlow/DeadCodeAnalysis.cpp
@@ -501,10 +501,11 @@ void DeadCodeAnalysis::visitRegionTerminator(Operation *op,
return;
SmallVector<RegionSuccessor> successors;
- auto terminator = dyn_cast<RegionBranchTerminatorOpInterface>(op);
- if (!terminator)
- return;
- terminator.getSuccessorRegions(*operands, successors);
+ if (auto terminator = dyn_cast<RegionBranchTerminatorOpInterface>(op))
+ terminator.getSuccessorRegions(*operands, successors);
+ else
+ branch.getSuccessorRegions(op->getParentRegion(), successors);
+
visitRegionBranchEdges(branch, op, successors);
}
diff --git a/mlir/lib/Analysis/DataFlow/DenseAnalysis.cpp b/mlir/lib/Analysis/DataFlow/DenseAnalysis.cpp
index 0682e5f26785a..daa3db55b2852 100644
--- a/mlir/lib/Analysis/DataFlow/DenseAnalysis.cpp
+++ b/mlir/lib/Analysis/DataFlow/DenseAnalysis.cpp
@@ -588,9 +588,7 @@ void AbstractDenseBackwardDataFlowAnalysis::visitBlock(Block *block) {
// flow, propagate the lattice back along the control flow edge.
if (auto branch = dyn_cast<RegionBranchOpInterface>(block->getParentOp())) {
LDBG() << " Exit block of region branch operation";
- auto terminator =
- cast<RegionBranchTerminatorOpInterface>(block->getTerminator());
- visitRegionBranchOperation(point, branch, terminator, before);
+ visitRegionBranchOperation(point, branch, block->getParent(), before);
return;
}
diff --git a/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp b/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp
index 8e63ae86753b4..0d2e2ed85549d 100644
--- a/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp
+++ b/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp
@@ -130,7 +130,7 @@ AbstractSparseForwardDataFlowAnalysis::visitOperation(Operation *op) {
// The results of a region branch operation are determined by control-flow.
if (auto branch = dyn_cast<RegionBranchOpInterface>(op)) {
visitRegionSuccessors(getProgramPointAfter(branch), branch,
- /*successor=*/{branch, branch->getResults()},
+ /*successor=*/RegionBranchPoint::parent(),
resultLattices);
return success();
}
@@ -279,7 +279,7 @@ void AbstractSparseForwardDataFlowAnalysis::visitCallableOperation(
void AbstractSparseForwardDataFlowAnalysis::visitRegionSuccessors(
ProgramPoint *point, RegionBranchOpInterface branch,
- RegionSuccessor successor, ArrayRef<AbstractSparseLattice *> lattices) {
+ RegionBranchPoint successor, ArrayRef<AbstractSparseLattice *> lattices) {
const auto *predecessors = getOrCreateFor<PredecessorState>(point, point);
assert(predecessors->allPredecessorsKnown() &&
"unexpected unresolved region successors");
@@ -314,7 +314,7 @@ void AbstractSparseForwardDataFlowAnalysis::visitRegionSuccessors(
visitNonControlFlowArgumentsImpl(
branch,
RegionSuccessor(
- branch, branch->getResults().slice(firstIndex, inputs.size())),
+ branch->getResults().slice(firstIndex, inputs.size())),
lattices, firstIndex);
} else {
if (!inputs.empty())
diff --git a/mlir/lib/Analysis/SliceWalk.cpp b/mlir/lib/Analysis/SliceWalk.cpp
index 863f260cd4b6a..817d71a3452ca 100644
--- a/mlir/lib/Analysis/SliceWalk.cpp
+++ b/mlir/lib/Analysis/SliceWalk.cpp
@@ -114,7 +114,7 @@ mlir::getControlFlowPredecessors(Value value) {
if (!regionOp)
return std::nullopt;
// Add the control flow predecessor operands to the work list.
- RegionSuccessor region(regionOp, regionOp->getResults());
+ RegionSuccessor region(regionOp->getResults());
SmallVector<Value> predecessorOperands = getRegionPredecessorOperands(
regionOp, region, opResult.getResultNumber());
return predecessorOperands;
diff --git a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
index 0c3592124cdec..e0a53cd52f143 100644
--- a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
+++ b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
@@ -2716,9 +2716,8 @@ LogicalResult AffineForOp::fold(FoldAdaptor adaptor,
return success(folded);
}
-OperandRange AffineForOp::getEntrySuccessorOperands(RegionSuccessor successor) {
- assert((successor.isParent() || successor.getSuccessor() == &getRegion()) &&
- "invalid region point");
+OperandRange AffineForOp::getEntrySuccessorOperands(RegionBranchPoint point) {
+ assert((point.isParent() || point == getRegion()) && "invalid region point");
// The initial operands map to the loop arguments after the induction
// variable or are forwarded to the results when the trip count is zero.
@@ -2727,41 +2726,34 @@ OperandRange AffineForOp::getEntrySuccessorOperands(RegionSuccessor successor) {
void AffineForOp::getSuccessorRegions(
RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
- assert((point.isParent() ||
- point.getTerminatorPredecessorOrNull()->getParentRegion() ==
- &getRegion()) &&
- "expected loop region");
+ assert((point.isParent() || point == getRegion()) && "expected loop region");
// The loop may typically branch back to its body or to the parent operation.
// If the predecessor is the parent op and the trip count is known to be at
// least one, branch into the body using the iterator arguments. And in cases
// we know the trip count is zero, it can only branch back to its parent.
std::optional<uint64_t> tripCount = getTrivialConstantTripCount(*this);
- if (tripCount.has_value()) {
- if (!point.isParent()) {
- // From the loop body, if the trip count is one, we can only branch back
- // to the parent.
- if (tripCount == 1) {
- regions.push_back(RegionSuccessor(getOperation(), getResults()));
- return;
- }
- if (tripCount == 0)
- return;
- } else {
- if (tripCount.value() > 0) {
- regions.push_back(RegionSuccessor(&getRegion(), getRegionIterArgs()));
- return;
- }
- if (tripCount.value() == 0) {
- regions.push_back(RegionSuccessor(getOperation(), getResults()));
- return;
- }
+ if (point.isParent() && tripCount.has_value()) {
+ if (tripCount.value() > 0) {
+ regions.push_back(RegionSuccessor(&getRegion(), getRegionIterArgs()));
+ return;
+ }
+ if (tripCount.value() == 0) {
+ regions.push_back(RegionSuccessor(getResults()));
+ return;
}
}
+ // From the loop body, if the trip count is one, we can only branch back to
+ // the parent.
+ if (!point.isParent() && tripCount == 1) {
+ regions.push_back(RegionSuccessor(getResults()));
+ return;
+ }
+
// In all other cases, the loop may branch back to itself or the parent
// operation.
regions.push_back(RegionSuccessor(&getRegion(), getRegionIterArgs()));
- regions.push_back(RegionSuccessor(getOperation(), getResults()));
+ regions.push_back(RegionSuccessor(getResults()));
}
AffineBound AffineForOp::getLowerBound() {
@@ -3150,7 +3142,7 @@ void AffineIfOp::getSuccessorRegions(
RegionSuccessor(&getThenRegion(), getThenRegion().getArguments()));
// If the "else" region is empty, branch bach into parent.
if (getElseRegion().empty()) {
- regions.push_back(RegionSuccessor(getOperation(), getResults()));
+ regions.push_back(getResults());
} else {
regions.push_back(
RegionSuccessor(&getElseRegion(), getElseRegion().getArguments()));
@@ -3160,7 +3152,7 @@ void AffineIfOp::getSuccessorRegions(
// If the predecessor is the `else`/`then` region, then branching into parent
// op is valid.
- regions.push_back(RegionSuccessor(getOperation(), getResults()));
+ regions.push_back(RegionSuccessor(getResults()));
}
LogicalResult AffineIfOp::verify() {
diff --git a/mlir/lib/Dialect/Async/IR/Async.cpp b/mlir/lib/Dialect/Async/IR/Async.cpp
index 8e4a49df76b52..dc7b07d911c17 100644
--- a/mlir/lib/Dialect/Async/IR/Async.cpp
+++ b/mlir/lib/Dialect/Async/IR/Async.cpp
@@ -36,9 +36,8 @@ void AsyncDialect::initialize() {
constexpr char kOperandSegmentSizesAttr[] = "operandSegmentSizes";
-OperandRange ExecuteOp::getEntrySuccessorOperands(RegionSuccessor successor) {
- assert(successor.getSuccessor() == &getBodyRegion() &&
- "invalid region index");
+OperandRange ExecuteOp::getEntrySuccessorOperands(RegionBranchPoint point) {
+ assert(point == getBodyRegion() && "invalid region index");
return getBodyOperands();
}
@@ -54,10 +53,8 @@ bool ExecuteOp::areTypesCompatible(Type lhs, Type rhs) {
void ExecuteOp::getSuccessorRegions(RegionBranchPoint point,
SmallVectorImpl<RegionSuccessor> &regions) {
// The `body` region branch back to the parent operation.
- if (!point.isParent() &&
- point.getTerminatorPredecessorOrNull()->getParentRegion() ==
- &getBodyRegion()) {
- regions.push_back(RegionSuccessor(getOperation(), getBodyResults()));
+ if (point == getBodyRegion()) {
+ regions.push_back(RegionSuccessor(getBodyResults()));
return;
}
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation.cpp b/mlir/lib/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation.cpp
index 36a759c279eb7..b593ccab060c7 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation.cpp
@@ -562,11 +562,8 @@ LogicalResult
BufferDeallocation::updateFunctionSignature(FunctionOpInterface op) {
SmallVector<TypeRange> returnOperandTypes(llvm::map_range(
op.getFunctionBody().getOps<RegionBranchTerminatorOpInterface>(),
- [&](RegionBranchTerminatorOpInterface branchOp) {
- return branchOp
- .getSuccessorOperands(RegionSuccessor(
- op.getOperation(), op.getOperation()->getResults()))
- .getTypes();
+ [](RegionBranchTerminatorOpInterface op) {
+ return op.getSuccessorOperands(RegionBranchPoint::parent()).getTypes();
}));
if (!llvm::all_equal(returnOperandTypes))
return op->emitError(
@@ -945,8 +942,8 @@ BufferDeallocation::handleInterface(RegionBranchTerminatorOpInterface op) {
// about, but we would need to check how many successors there are and under
// which condition they are taken, etc.
- MutableOperandRange operands = op.getMutableSuccessorOperands(
- RegionSuccessor(op.getOperation(), op.getOperation()->getResults()));
+ MutableOperandRange operands =
+ op.getMutableSuccessorOperands(RegionBranchPoint::parent());
SmallVector<Value> updatedOwnerships;
auto result = deallocation_impl::insertDeallocOpForReturnLike(
diff --git a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp
index 0992ce14b4afb..4754f0bfe895e 100644
--- a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp
+++ b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp
@@ -845,8 +845,7 @@ void IfOp::getSuccessorRegions(RegionBranchPoint point,
SmallVectorImpl<RegionSuccessor> &regions) {
// The `then` and the `else` region branch back to the parent operation.
if (!point.isParent()) {
- regions.push_back(
- RegionSuccessor(getOperation(), getOperation()->getResults()));
+ regions.push_back(RegionSuccessor());
return;
}
@@ -855,8 +854,7 @@ void IfOp::getSuccessorRegions(RegionBranchPoint point,
// Don't consider the else region if it is empty.
Region *elseRegion = &this->getElseRegion();
if (elseRegion->empty())
- regions.push_back(
- RegionSuccessor(getOperation(), getOperation()->getResults()));
+ regions.push_back(RegionSuccessor());
else
regions.push_back(RegionSuccessor(elseRegion));
}
@@ -873,7 +871,7 @@ void IfOp::getEntrySuccessorRegions(ArrayRef<Attribute> operands,
if (!getElseRegion().empty())
regions.emplace_back(&getElseRegion());
else
- regions.emplace_back(getOperation(), getOperation()->getResults());
+ regions.emplace_back();
}
}
diff --git a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
index 6c6d8d2bad55d..b5f8ddaadacdf 100644
--- a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
+++ b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
@@ -2399,7 +2399,7 @@ ParseResult WarpExecuteOnLane0Op::parse(OpAsmParser &parser,
void WarpExecuteOnLane0Op::getSuccessorRegions(
RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
if (!point.isParent()) {
- regions.push_back(RegionSuccessor(getOperation(), getResults()));
+ regions.push_back(RegionSuccessor(getResults()));
return;
}
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index 1c21a2f270da6..c551fba93e367 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -405,7 +405,7 @@ ParseResult AllocaScopeOp::parse(OpAsmParser &parser, OperationState &result) {
void AllocaScopeOp::getSuccessorRegions(
RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
if (!point.isParent()) {
- regions.push_back(RegionSuccessor(getOperation(), getResults()));
+ regions.push_back(RegionSuccessor(getResults()));
return;
}
diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
index 2946b53c8cb36..1ab01d86bcd10 100644
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -397,7 +397,7 @@ void ExecuteRegionOp::getSuccessorRegions(
}
// Otherwise, the region branches back to the parent operation.
- regions.push_back(RegionSuccessor(getOperation(), getResults()));
+ regions.push_back(RegionSuccessor(getResults()));
}
//===----------------------------------------------------------------------===//
@@ -405,11 +405,10 @@ void ExecuteRegionOp::getSuccessorRegions(
//===----------------------------------------------------------------------===//
MutableOperandRange
-ConditionOp::getMutableSuccessorOperands(RegionSuccessor point) {
- assert(
- (point.isParent() || point.getSuccessor() == &getParentOp().getAfter()) &&
- "condition op can only exit the loop or branch to the after"
- "region");
+ConditionOp::getMutableSuccessorOperands(RegionBranchPoint point) {
+ assert((point.isParent() || point == getParentOp().getAfter()) &&
+ "condition op can only exit the loop or branch to the after"
+ "region");
// Pass all operands except the condition to the successor region.
return getArgsMutable();
}
@@ -427,7 +426,7 @@ void ConditionOp::getSuccessorRegions(
regions.emplace_back(&whileOp.getAfter(),
whileOp.getAfter().getArguments());
if (!boolAttr || !boolAttr.getValue())
- regions.emplace_back(whileOp.getOperation(), whileOp.getResults());
+ regions.emplace_back(whileOp.getResults());
}
//===----------------------------------------------------------------------===//
@@ -750,7 +749,7 @@ ForOp mlir::scf::getForInductionVarOwner(Value val) {
return dyn_cast_or_null<ForOp>(containingOp);
}
-OperandRange ForOp::getEntrySuccessorOperands(RegionSuccessor successor) {
+OperandRange ForOp::getEntrySuccessorOperands(RegionBranchPoint point) {
return getInitArgs();
}
@@ -760,7 +759,7 @@ void ForOp::getSuccessorRegions(RegionBranchPoint point,
// back into the operation itself. It is possible for loop not to enter the
// body.
regions.push_back(RegionSuccessor(&getRegion(), getRegionIterArgs()));
- regions.push_back(RegionSuccessor(getOperation(), getResults()));
+ regions.push_back(RegionSuccessor(getResults()));
}
SmallVector<Region *> ForallOp::getLoopRegions() { return {&getRegion()}; }
@@ -2054,10 +2053,9 @@ void ForallOp::getSuccessorRegions(RegionBranchPoint point,
// parallel by multiple threads. We should not expect to branch back into
// the forall body after the region's execution is complete.
if (point.isParent())
- regions.push_back(RegionSuccessor(&getRegion(), getRegionIterArgs()));
+ regions.push_back(RegionSuccessor(&getRegion()));
else
- regions.push_back(
- RegionSuccessor(getOperation(), getOperation()->getResults()));
+ regions.push_back(RegionSuccessor());
}
//===----------------------------------------------------------------------===//
@@ -2335,10 +2333,9 @@ void IfOp::print(OpAsmPrinter &p) {
void IfOp::getSuccessorRegions(RegionBranchPoint point,
SmallVectorImpl<RegionSuccessor> &regions) {
- // The `then` and the `else` region branch back to the parent operation or one
- // of the recursive parent operations (early exit case).
+ // The `then` and the `else` region branch back to the parent operation.
if (!point.isParent()) {
- regions.push_back(RegionSuccessor(getOperation(), getResults()));
+ regions.push_back(RegionSuccessor(getResults()));
return;
}
@@ -2347,8 +2344,7 @@ void IfOp::getSuccessorRegions(RegionBranchPoint point,
// Don't consider the else region if it is empty.
Region *elseRegion = &this->getElseRegion();
if (elseRegion->empty())
- regions.push_back(
- RegionSuccessor(getOperation(), getOperation()->getResults()));
+ regions.push_back(RegionSuccessor());
else
regions.push_back(RegionSuccessor(elseRegion));
}
@@ -2365,7 +2361,7 @@ void IfOp::getEntrySuccessorRegions(ArrayRef<Attribute> operands,
if (!getElseRegion().empty())
regions.emplace_back(&getElseRegion());
else
- regions.emplace_back(getOperation(), getResults());
+ regions.emplace_back(getResults());
}
}
@@ -3389,8 +3385,7 @@ void ParallelOp::getSuccessorRegions(
// back into the operation itself. It is possible for loop not to enter the
// body.
regions.push_back(RegionSuccessor(&getRegion()));
- regions.push_back(RegionSuccessor(
- getOperation(), ResultRange{getResults().end(), getResults().end()}));
+ regions.push_back(RegionSuccessor());
}
//===----------------------------------------------------------------------===//
@@ -3436,7 +3431,7 @@ LogicalResult ReduceOp::verifyRegions() {
}
MutableOperandRange
-ReduceOp::getMutableSuccessorOperands(RegionSuccessor point) {
+ReduceOp::getMutableSuccessorOperands(RegionBranchPoint point) {
// No operands are forwarded to the next iteration.
return MutableOperandRange(getOperation(), /*start=*/0, /*length=*/0);
}
@@ -3519,8 +3514,8 @@ Block::BlockArgListType WhileOp::getRegionIterArgs() {
return getBeforeArguments();
}
-OperandRange WhileOp::getEntrySuccessorOperands(RegionSuccessor successor) {
- assert(successor.getSuccessor() == &getBefore() &&
+OperandRange WhileOp::getEntrySuccessorOperands(RegionBranchPoint point) {
+ assert(point == getBefore() &&
"WhileOp is expected to branch only to the first region");
return getInits();
}
@@ -3533,18 +3528,15 @@ void WhileOp::getSuccessorRegions(RegionBranchPoint point,
return;
}
- assert(llvm::is_contained(
- {&getAfter(), &getBefore()},
- point.getTerminatorPredecessorOrNull()->getParentRegion()) &&
+ assert(llvm::is_contained({&getAfter(), &getBefore()}, point) &&
"there are only two regions in a WhileOp");
// The body region always branches back to the condition region.
- if (point.getTerminatorPredecessorOrNull()->getParentRegion() ==
- &getAfter()) {
+ if (point == getAfter()) {
regions.emplace_back(&getBefore(), getBefore().getArguments());
return;
}
- regions.emplace_back(getOperation(), getResults());
+ regions.emplace_back(getResults());
regions.emplace_back(&getAfter(), getAfter().getArguments());
}
@@ -4453,7 +4445,7 @@ void IndexSwitchOp::getSuccessorRegions(
RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &successors) {
// All regions branch back to the parent op.
if (!point.isParent()) {
- successors.emplace_back(getOperation(), getResults());
+ successors.emplace_back(getResults());
return;
}
diff --git a/mlir/lib/Dialect/SCF/Transforms/ForToWhile.cpp b/mlir/lib/Dialect/SCF/Transforms/ForToWhile.cpp
index ddcbda86cf1f3..ae52af5009dc9 100644
--- a/mlir/lib/Dialect/SCF/Transforms/ForToWhile.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/ForToWhile.cpp
@@ -23,6 +23,7 @@ namespace mlir {
#include "mlir/Dialect/SCF/Transforms/Passes.h.inc"
} // namespace mlir
+using namespace llvm;
using namespace mlir;
using scf::ForOp;
using scf::WhileOp;
diff --git a/mlir/lib/Dialect/SCF/Transforms/ForallToFor.cpp b/mlir/lib/Dialect/SCF/Transforms/ForallToFor.cpp
index 00bef707fadd3..a2f03f1e1056e 100644
--- a/mlir/lib/Dialect/SCF/Transforms/ForallToFor.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/ForallToFor.cpp
@@ -21,6 +21,7 @@ namespace mlir {
#include "mlir/Dialect/SCF/Transforms/Passes.h.inc"
} // namespace mlir
+using namespace llvm;
using namespace mlir;
using scf::LoopNest;
diff --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp
index f0f22e5ef4a83..5ba828918c22a 100644
--- a/mlir/lib/Dialect/Shape/IR/Shape.cpp
+++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp
@@ -346,7 +346,7 @@ void AssumingOp::getSuccessorRegions(
// parent, so return the correct RegionSuccessor purely based on the index
// being None or 0.
if (!point.isParent()) {
- regions.push_back(RegionSuccessor(getOperation(), getResults()));
+ regions.push_back(RegionSuccessor(getResults()));
return;
}
diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
index 3962e3e84dd31..1a9d9e158ee75 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -2597,7 +2597,7 @@ std::optional<MutableArrayRef<OpOperand>> IterateOp::getYieldedValuesMutable() {
std::optional<ResultRange> IterateOp::getLoopResults() { return getResults(); }
-OperandRange IterateOp::getEntrySuccessorOperands(RegionSuccessor successor) {
+OperandRange IterateOp::getEntrySuccessorOperands(RegionBranchPoint point) {
return getInitArgs();
}
@@ -2607,7 +2607,7 @@ void IterateOp::getSuccessorRegions(RegionBranchPoint point,
// or back into the operation itself.
regions.push_back(RegionSuccessor(&getRegion(), getRegionIterArgs()));
// It is possible for loop not to enter the body.
- regions.push_back(RegionSuccessor(getOperation(), getResults()));
+ regions.push_back(RegionSuccessor(getResults()));
}
void CoIterateOp::build(OpBuilder &builder, OperationState &odsState,
diff --git a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
index 062606e7e10b6..365afab3764c8 100644
--- a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
@@ -96,9 +96,9 @@ ensurePayloadIsSeparateFromTransform(transform::TransformOpInterface transform,
// AlternativesOp
//===----------------------------------------------------------------------===//
-OperandRange transform::AlternativesOp::getEntrySuccessorOperands(
- RegionSuccessor successor) {
- if (!successor.isParent() && getOperation()->getNumOperands() == 1)
+OperandRange
+transform::AlternativesOp::getEntrySuccessorOperands(RegionBranchPoint point) {
+ if (!point.isParent() && getOperation()->getNumOperands() == 1)
return getOperation()->getOperands();
return OperandRange(getOperation()->operand_end(),
getOperation()->operand_end());
@@ -107,18 +107,15 @@ OperandRange transform::AlternativesOp::getEntrySuccessorOperands(
void transform::AlternativesOp::getSuccessorRegions(
RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> ®ions) {
for (Region &alternative : llvm::drop_begin(
- getAlternatives(), point.isParent()
- ? 0
- : point.getTerminatorPredecessorOrNull()
- ->getParentRegion()
- ->getRegionNumber() +
- 1)) {
+ getAlternatives(),
+ point.isParent() ? 0
+ : point.getRegionOrNull()->getRegionNumber() + 1)) {
regions.emplace_back(&alternative, !getOperands().empty()
? alternative.getArguments()
: Block::BlockArgListType());
}
if (!point.isParent())
- regions.emplace_back(getOperation(), getOperation()->getResults());
+ regions.emplace_back(getOperation()->getResults());
}
void transform::AlternativesOp::getRegionInvocationBounds(
@@ -1743,18 +1740,16 @@ void transform::ForeachOp::getSuccessorRegions(
}
// Branch back to the region or the parent.
- assert(point.getTerminatorPredecessorOrNull()->getParentRegion() ==
- &getBody() &&
- "unexpected region index");
+ assert(point == getBody() && "unexpected region index");
regions.emplace_back(bodyRegion, bodyRegion->getArguments());
- regions.emplace_back(getOperation(), getOperation()->getResults());
+ regions.emplace_back();
}
OperandRange
-transform::ForeachOp::getEntrySuccessorOperands(RegionSuccessor successor) {
+transform::ForeachOp::getEntrySuccessorOperands(RegionBranchPoint point) {
// Each block argument handle is mapped to a subset (one op to be precise)
// of the payload of the corresponding `targets` operand of ForeachOp.
- assert(successor.getSuccessor() == &getBody() && "unexpected region index");
+ assert(point == getBody() && "unexpected region index");
return getOperation()->getOperands();
}
@@ -2953,8 +2948,8 @@ void transform::SequenceOp::getEffects(
}
OperandRange
-transform::SequenceOp::getEntrySuccessorOperands(RegionSuccessor successor) {
- assert(successor.getSuccessor() == &getBody() && "unexpected region index");
+transform::SequenceOp::getEntrySuccessorOperands(RegionBranchPoint point) {
+ assert(point == getBody() && "unexpected region index");
if (getOperation()->getNumOperands() > 0)
return getOperation()->getOperands();
return OperandRange(getOperation()->operand_end(),
@@ -2971,10 +2966,8 @@ void transform::SequenceOp::getSuccessorRegions(
return;
}
- assert(point.getTerminatorPredecessorOrNull()->getParentRegion() ==
- &getBody() &&
- "unexpected region index");
- regions.emplace_back(getOperation(), getOperation()->getResults());
+ assert(point == getBody() && "unexpected region index");
+ regions.emplace_back(getOperation()->getResults());
}
void transform::SequenceOp::getRegionInvocationBounds(
diff --git a/mlir/lib/Dialect/Transform/TuneExtension/TuneExtensionOps.cpp b/mlir/lib/Dialect/Transform/TuneExtension/TuneExtensionOps.cpp
index f727118f3f9a0..c627158e999ed 100644
--- a/mlir/lib/Dialect/Transform/TuneExtension/TuneExtensionOps.cpp
+++ b/mlir/lib/Dialect/Transform/TuneExtension/TuneExtensionOps.cpp
@@ -9,7 +9,6 @@
#include "mlir/Dialect/Transform/IR/TransformOps.h"
#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
#include "mlir/IR/OpImplementation.h"
-#include "mlir/Interfaces/ControlFlowInterfaces.h"
#include "llvm/Support/Debug.h"
#include "mlir/Dialect/Transform/TuneExtension/TuneExtensionOps.h"
@@ -113,7 +112,7 @@ static void printAlternativesOpSelectedRegion(OpAsmPrinter &printer,
}
OperandRange transform::tune::AlternativesOp::getEntrySuccessorOperands(
- RegionSuccessor successor) {
+ RegionBranchPoint point) {
// No operands will be forwarded to the region(s).
return getOperands().slice(0, 0);
}
@@ -129,7 +128,7 @@ void transform::tune::AlternativesOp::getSuccessorRegions(
for (Region &alternative : getAlternatives())
regions.emplace_back(&alternative, Block::BlockArgListType());
else
- regions.emplace_back(getOperation(), getOperation()->getResults());
+ regions.emplace_back(getOperation()->getResults());
}
void transform::tune::AlternativesOp::getRegionInvocationBounds(
diff --git a/mlir/lib/IR/Diagnostics.cpp b/mlir/lib/IR/Diagnostics.cpp
index f4c9242ed3479..776b5c6588c71 100644
--- a/mlir/lib/IR/Diagnostics.cpp
+++ b/mlir/lib/IR/Diagnostics.cpp
@@ -138,10 +138,6 @@ Diagnostic &Diagnostic::operator<<(Operation &op) {
return appendOp(op, OpPrintingFlags());
}
-Diagnostic &Diagnostic::operator<<(OpWithFlags op) {
- return appendOp(*op.getOperation(), op.flags());
-}
-
Diagnostic &Diagnostic::appendOp(Operation &op, const OpPrintingFlags &flags) {
std::string str;
llvm::raw_string_ostream os(str);
diff --git a/mlir/lib/IR/Region.cpp b/mlir/lib/IR/Region.cpp
index 15a941f380225..46b6298076d48 100644
--- a/mlir/lib/IR/Region.cpp
+++ b/mlir/lib/IR/Region.cpp
@@ -253,21 +253,6 @@ void Region::OpIterator::skipOverBlocksWithNoOps() {
operation = block->begin();
}
-llvm::raw_ostream &mlir::operator<<(llvm::raw_ostream &os, Region ®ion) {
- if (!region.getParentOp()) {
- os << "Region has no parent op";
- } else {
- os << "Region #" << region.getRegionNumber() << " in operation "
- << region.getParentOp()->getName();
- }
- for (auto it : llvm::enumerate(region.getBlocks())) {
- os << "\n Block #" << it.index() << ":";
- for (Operation &op : it.value().getOperations())
- os << "\n " << OpWithFlags(&op, OpPrintingFlags().skipRegions());
- }
- return os;
-}
-
//===----------------------------------------------------------------------===//
// RegionRange
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Interfaces/ControlFlowInterfaces.cpp b/mlir/lib/Interfaces/ControlFlowInterfaces.cpp
index 1e56810ff7aaf..ca3f7666dba8a 100644
--- a/mlir/lib/Interfaces/ControlFlowInterfaces.cpp
+++ b/mlir/lib/Interfaces/ControlFlowInterfaces.cpp
@@ -9,9 +9,7 @@
#include <utility>
#include "mlir/IR/BuiltinTypes.h"
-#include "mlir/IR/Operation.h"
#include "mlir/Interfaces/ControlFlowInterfaces.h"
-#include "llvm/Support/DebugLog.h"
using namespace mlir;
@@ -40,31 +38,20 @@ SuccessorOperands::SuccessorOperands(unsigned int producedOperandCount,
std::optional<BlockArgument>
detail::getBranchSuccessorArgument(const SuccessorOperands &operands,
unsigned operandIndex, Block *successor) {
- LDBG() << "Getting branch successor argument for operand index "
- << operandIndex << " in successor block";
-
OperandRange forwardedOperands = operands.getForwardedOperands();
// Check that the operands are valid.
- if (forwardedOperands.empty()) {
- LDBG() << "No forwarded operands, returning nullopt";
+ if (forwardedOperands.empty())
return std::nullopt;
- }
// Check to ensure that this operand is within the range.
unsigned operandsStart = forwardedOperands.getBeginOperandIndex();
if (operandIndex < operandsStart ||
- operandIndex >= (operandsStart + forwardedOperands.size())) {
- LDBG() << "Operand index " << operandIndex << " out of range ["
- << operandsStart << ", "
- << (operandsStart + forwardedOperands.size())
- << "), returning nullopt";
+ operandIndex >= (operandsStart + forwardedOperands.size()))
return std::nullopt;
- }
// Index the successor.
unsigned argIndex =
operands.getProducedOperandCount() + operandIndex - operandsStart;
- LDBG() << "Computed argument index " << argIndex << " for successor block";
return successor->getArgument(argIndex);
}
@@ -72,15 +59,9 @@ detail::getBranchSuccessorArgument(const SuccessorOperands &operands,
LogicalResult
detail::verifyBranchSuccessorOperands(Operation *op, unsigned succNo,
const SuccessorOperands &operands) {
- LDBG() << "Verifying branch successor operands for successor #" << succNo
- << " in operation " << op->getName();
-
// Check the count.
unsigned operandCount = operands.size();
Block *destBB = op->getSuccessor(succNo);
- LDBG() << "Branch has " << operandCount << " operands, target block has "
- << destBB->getNumArguments() << " arguments";
-
if (operandCount != destBB->getNumArguments())
return op->emitError() << "branch has " << operandCount
<< " operands for successor #" << succNo
@@ -88,22 +69,13 @@ detail::verifyBranchSuccessorOperands(Operation *op, unsigned succNo,
<< destBB->getNumArguments();
// Check the types.
- LDBG() << "Checking type compatibility for "
- << (operandCount - operands.getProducedOperandCount())
- << " forwarded operands";
for (unsigned i = operands.getProducedOperandCount(); i != operandCount;
++i) {
- Type operandType = operands[i].getType();
- Type argType = destBB->getArgument(i).getType();
- LDBG() << "Checking type compatibility: operand type " << operandType
- << " vs argument type " << argType;
-
- if (!cast<BranchOpInterface>(op).areTypesCompatible(operandType, argType))
+ if (!cast<BranchOpInterface>(op).areTypesCompatible(
+ operands[i].getType(), destBB->getArgument(i).getType()))
return op->emitError() << "type mismatch for bb argument #" << i
<< " of successor #" << succNo;
}
-
- LDBG() << "Branch successor operand verification successful";
return success();
}
@@ -154,15 +126,15 @@ LogicalResult detail::verifyRegionBranchWeights(Operation *op) {
static InFlightDiagnostic &printRegionEdgeName(InFlightDiagnostic &diag,
RegionBranchPoint sourceNo,
- RegionSuccessor succRegionNo) {
+ RegionBranchPoint succRegionNo) {
diag << "from ";
- if (Operation *op = sourceNo.getTerminatorPredecessorOrNull())
- diag << "Operation " << op->getName();
+ if (Region *region = sourceNo.getRegionOrNull())
+ diag << "Region #" << region->getRegionNumber();
else
diag << "parent operands";
diag << " to ";
- if (Region *region = succRegionNo.getSuccessor())
+ if (Region *region = succRegionNo.getRegionOrNull())
diag << "Region #" << region->getRegionNumber();
else
diag << "parent results";
@@ -173,12 +145,13 @@ static InFlightDiagnostic &printRegionEdgeName(InFlightDiagnostic &diag,
/// `sourcePoint`. `getInputsTypesForRegion` is a function that returns the
/// types of the inputs that flow to a successor region.
static LogicalResult
-verifyTypesAlongAllEdges(RegionBranchOpInterface branchOp,
- RegionBranchPoint sourcePoint,
- function_ref<FailureOr<TypeRange>(RegionSuccessor)>
+verifyTypesAlongAllEdges(Operation *op, RegionBranchPoint sourcePoint,
+ function_ref<FailureOr<TypeRange>(RegionBranchPoint)>
getInputsTypesForRegion) {
+ auto regionInterface = cast<RegionBranchOpInterface>(op);
+
SmallVector<RegionSuccessor, 2> successors;
- branchOp.getSuccessorRegions(sourcePoint, successors);
+ regionInterface.getSuccessorRegions(sourcePoint, successors);
for (RegionSuccessor &succ : successors) {
FailureOr<TypeRange> sourceTypes = getInputsTypesForRegion(succ);
@@ -187,14 +160,10 @@ verifyTypesAlongAllEdges(RegionBranchOpInterface branchOp,
TypeRange succInputsTypes = succ.getSuccessorInputs().getTypes();
if (sourceTypes->size() != succInputsTypes.size()) {
- InFlightDiagnostic diag =
- branchOp->emitOpError("region control flow edge ");
- std::string succStr;
- llvm::raw_string_ostream os(succStr);
- os << succ;
+ InFlightDiagnostic diag = op->emitOpError("region control flow edge ");
return printRegionEdgeName(diag, sourcePoint, succ)
<< ": source has " << sourceTypes->size()
- << " operands, but target successor " << os.str() << " needs "
+ << " operands, but target successor needs "
<< succInputsTypes.size();
}
@@ -202,10 +171,8 @@ verifyTypesAlongAllEdges(RegionBranchOpInterface branchOp,
llvm::enumerate(llvm::zip(*sourceTypes, succInputsTypes))) {
Type sourceType = std::get<0>(typesIdx.value());
Type inputType = std::get<1>(typesIdx.value());
-
- if (!branchOp.areTypesCompatible(sourceType, inputType)) {
- InFlightDiagnostic diag =
- branchOp->emitOpError("along control flow edge ");
+ if (!regionInterface.areTypesCompatible(sourceType, inputType)) {
+ InFlightDiagnostic diag = op->emitOpError("along control flow edge ");
return printRegionEdgeName(diag, sourcePoint, succ)
<< ": source type #" << typesIdx.index() << " " << sourceType
<< " should match input type #" << typesIdx.index() << " "
@@ -213,7 +180,6 @@ verifyTypesAlongAllEdges(RegionBranchOpInterface branchOp,
}
}
}
-
return success();
}
@@ -221,18 +187,34 @@ verifyTypesAlongAllEdges(RegionBranchOpInterface branchOp,
LogicalResult detail::verifyTypesAlongControlFlowEdges(Operation *op) {
auto regionInterface = cast<RegionBranchOpInterface>(op);
- auto inputTypesFromParent = [&](RegionSuccessor successor) -> TypeRange {
- return regionInterface.getEntrySuccessorOperands(successor).getTypes();
+ auto inputTypesFromParent = [&](RegionBranchPoint point) -> TypeRange {
+ return regionInterface.getEntrySuccessorOperands(point).getTypes();
};
// Verify types along control flow edges originating from the parent.
- if (failed(verifyTypesAlongAllEdges(
- regionInterface, RegionBranchPoint::parent(), inputTypesFromParent)))
+ if (failed(verifyTypesAlongAllEdges(op, RegionBranchPoint::parent(),
+ inputTypesFromParent)))
return failure();
+ auto areTypesCompatible = [&](TypeRange lhs, TypeRange rhs) {
+ if (lhs.size() != rhs.size())
+ return false;
+ for (auto types : llvm::zip(lhs, rhs)) {
+ if (!regionInterface.areTypesCompatible(std::get<0>(types),
+ std::get<1>(types))) {
+ return false;
+ }
+ }
+ return true;
+ };
+
// Verify types along control flow edges originating from each region.
for (Region ®ion : op->getRegions()) {
- // Collect all return-like terminators in the region.
+
+ // Since there can be multiple terminators implementing the
+ // `RegionBranchTerminatorOpInterface`, all should have the same operand
+ // types when passing them to the same region.
+
SmallVector<RegionBranchTerminatorOpInterface> regionReturnOps;
for (Block &block : region)
if (!block.empty())
@@ -245,20 +227,33 @@ LogicalResult detail::verifyTypesAlongControlFlowEdges(Operation *op) {
if (regionReturnOps.empty())
continue;
- // Verify types along control flow edges originating from each return-like
- // terminator.
- for (RegionBranchTerminatorOpInterface regionReturnOp : regionReturnOps) {
-
- auto inputTypesForRegion =
- [&](RegionSuccessor successor) -> FailureOr<TypeRange> {
- OperandRange terminatorOperands =
- regionReturnOp.getSuccessorOperands(successor);
- return TypeRange(terminatorOperands.getTypes());
- };
- if (failed(verifyTypesAlongAllEdges(regionInterface, regionReturnOp,
- inputTypesForRegion)))
- return failure();
- }
+ auto inputTypesForRegion =
+ [&](RegionBranchPoint point) -> FailureOr<TypeRange> {
+ std::optional<OperandRange> regionReturnOperands;
+ for (RegionBranchTerminatorOpInterface regionReturnOp : regionReturnOps) {
+ auto terminatorOperands = regionReturnOp.getSuccessorOperands(point);
+
+ if (!regionReturnOperands) {
+ regionReturnOperands = terminatorOperands;
+ continue;
+ }
+
+ // Found more than one ReturnLike terminator. Make sure the operand
+ // types match with the first one.
+ if (!areTypesCompatible(regionReturnOperands->getTypes(),
+ terminatorOperands.getTypes())) {
+ InFlightDiagnostic diag = op->emitOpError("along control flow edge");
+ return printRegionEdgeName(diag, region, point)
+ << " operands mismatch between return-like terminators";
+ }
+ }
+
+ // All successors get the same set of operand types.
+ return TypeRange(regionReturnOperands->getTypes());
+ };
+
+ if (failed(verifyTypesAlongAllEdges(op, region, inputTypesForRegion)))
+ return failure();
}
return success();
@@ -277,74 +272,31 @@ using StopConditionFn = function_ref<bool(Region *, ArrayRef<bool> visited)>;
static bool traverseRegionGraph(Region *begin,
StopConditionFn stopConditionFn) {
auto op = cast<RegionBranchOpInterface>(begin->getParentOp());
- LDBG() << "Starting region graph traversal from region #"
- << begin->getRegionNumber() << " in operation " << op->getName();
-
SmallVector<bool> visited(op->getNumRegions(), false);
visited[begin->getRegionNumber()] = true;
- LDBG() << "Initialized visited array with " << op->getNumRegions()
- << " regions";
// Retrieve all successors of the region and enqueue them in the worklist.
SmallVector<Region *> worklist;
auto enqueueAllSuccessors = [&](Region *region) {
- LDBG() << "Enqueuing successors for region #" << region->getRegionNumber();
- SmallVector<Attribute> operandAttributes(op->getNumOperands());
- for (Block &block : *region) {
- if (block.empty())
- continue;
- auto terminator =
- dyn_cast<RegionBranchTerminatorOpInterface>(block.back());
- if (!terminator)
- continue;
- SmallVector<RegionSuccessor> successors;
- operandAttributes.resize(terminator->getNumOperands());
- terminator.getSuccessorRegions(operandAttributes, successors);
- LDBG() << "Found " << successors.size()
- << " successors from terminator in block";
- for (RegionSuccessor successor : successors) {
- if (!successor.isParent()) {
- worklist.push_back(successor.getSuccessor());
- LDBG() << "Added region #"
- << successor.getSuccessor()->getRegionNumber()
- << " to worklist";
- } else {
- LDBG() << "Skipping parent successor";
- }
- }
- }
+ SmallVector<RegionSuccessor> successors;
+ op.getSuccessorRegions(region, successors);
+ for (RegionSuccessor successor : successors)
+ if (!successor.isParent())
+ worklist.push_back(successor.getSuccessor());
};
enqueueAllSuccessors(begin);
- LDBG() << "Initial worklist size: " << worklist.size();
// Process all regions in the worklist via DFS.
while (!worklist.empty()) {
Region *nextRegion = worklist.pop_back_val();
- LDBG() << "Processing region #" << nextRegion->getRegionNumber()
- << " from worklist (remaining: " << worklist.size() << ")";
-
- if (stopConditionFn(nextRegion, visited)) {
- LDBG() << "Stop condition met for region #"
- << nextRegion->getRegionNumber() << ", returning true";
+ if (stopConditionFn(nextRegion, visited))
return true;
- }
- llvm::dbgs() << "Region: " << nextRegion << "\n";
- if (!nextRegion->getParentOp()) {
- llvm::errs() << "Region " << *nextRegion << " has no parent op\n";
- return false;
- }
- if (visited[nextRegion->getRegionNumber()]) {
- LDBG() << "Region #" << nextRegion->getRegionNumber()
- << " already visited, skipping";
+ if (visited[nextRegion->getRegionNumber()])
continue;
- }
visited[nextRegion->getRegionNumber()] = true;
- LDBG() << "Marking region #" << nextRegion->getRegionNumber()
- << " as visited";
enqueueAllSuccessors(nextRegion);
}
- LDBG() << "Traversal completed, returning false";
return false;
}
@@ -370,26 +322,18 @@ static bool isRegionReachable(Region *begin, Region *r) {
/// mutually exclusive if they are not reachable from each other as per
/// RegionBranchOpInterface::getSuccessorRegions.
bool mlir::insideMutuallyExclusiveRegions(Operation *a, Operation *b) {
- LDBG() << "Checking if operations are in mutually exclusive regions: "
- << a->getName() << " and " << b->getName();
-
assert(a && "expected non-empty operation");
assert(b && "expected non-empty operation");
auto branchOp = a->getParentOfType<RegionBranchOpInterface>();
while (branchOp) {
- LDBG() << "Checking branch operation " << branchOp->getName();
-
// Check if b is inside branchOp. (We already know that a is.)
if (!branchOp->isProperAncestor(b)) {
- LDBG() << "Operation b is not inside branchOp, checking next ancestor";
// Check next enclosing RegionBranchOpInterface.
branchOp = branchOp->getParentOfType<RegionBranchOpInterface>();
continue;
}
- LDBG() << "Both operations are inside branchOp, finding their regions";
-
// b is contained in branchOp. Retrieve the regions in which `a` and `b`
// are contained.
Region *regionA = nullptr, *regionB = nullptr;
@@ -397,136 +341,63 @@ bool mlir::insideMutuallyExclusiveRegions(Operation *a, Operation *b) {
if (r.findAncestorOpInRegion(*a)) {
assert(!regionA && "already found a region for a");
regionA = &r;
- LDBG() << "Found region #" << r.getRegionNumber() << " for operation a";
}
if (r.findAncestorOpInRegion(*b)) {
assert(!regionB && "already found a region for b");
regionB = &r;
- LDBG() << "Found region #" << r.getRegionNumber() << " for operation b";
}
}
assert(regionA && regionB && "could not find region of op");
- LDBG() << "Region A: #" << regionA->getRegionNumber() << ", Region B: #"
- << regionB->getRegionNumber();
-
// `a` and `b` are in mutually exclusive regions if both regions are
// distinct and neither region is reachable from the other region.
- bool regionsAreDistinct = (regionA != regionB);
- bool aNotReachableFromB = !isRegionReachable(regionA, regionB);
- bool bNotReachableFromA = !isRegionReachable(regionB, regionA);
-
- LDBG() << "Regions distinct: " << regionsAreDistinct
- << ", A not reachable from B: " << aNotReachableFromB
- << ", B not reachable from A: " << bNotReachableFromA;
-
- bool mutuallyExclusive =
- regionsAreDistinct && aNotReachableFromB && bNotReachableFromA;
- LDBG() << "Operations are mutually exclusive: " << mutuallyExclusive;
-
- return mutuallyExclusive;
+ return regionA != regionB && !isRegionReachable(regionA, regionB) &&
+ !isRegionReachable(regionB, regionA);
}
// Could not find a common RegionBranchOpInterface among a's and b's
// ancestors.
- LDBG() << "No common RegionBranchOpInterface found, operations are not "
- "mutually exclusive";
return false;
}
bool RegionBranchOpInterface::isRepetitiveRegion(unsigned index) {
- LDBG() << "Checking if region #" << index << " is repetitive in operation "
- << getOperation()->getName();
-
Region *region = &getOperation()->getRegion(index);
- bool isRepetitive = isRegionReachable(region, region);
-
- LDBG() << "Region #" << index << " is repetitive: " << isRepetitive;
- return isRepetitive;
+ return isRegionReachable(region, region);
}
bool RegionBranchOpInterface::hasLoop() {
- LDBG() << "Checking if operation " << getOperation()->getName()
- << " has loops";
-
SmallVector<RegionSuccessor> entryRegions;
getSuccessorRegions(RegionBranchPoint::parent(), entryRegions);
- LDBG() << "Found " << entryRegions.size() << " entry regions";
-
- for (RegionSuccessor successor : entryRegions) {
- if (!successor.isParent()) {
- LDBG() << "Checking entry region #"
- << successor.getSuccessor()->getRegionNumber() << " for loops";
-
- bool hasLoop =
- traverseRegionGraph(successor.getSuccessor(),
- [](Region *nextRegion, ArrayRef<bool> visited) {
- // Interrupt traversal if the region was already
- // visited.
- return visited[nextRegion->getRegionNumber()];
- });
-
- if (hasLoop) {
- LDBG() << "Found loop in entry region #"
- << successor.getSuccessor()->getRegionNumber();
- return true;
- }
- } else {
- LDBG() << "Skipping parent successor";
- }
- }
-
- LDBG() << "No loops found in operation";
+ for (RegionSuccessor successor : entryRegions)
+ if (!successor.isParent() &&
+ traverseRegionGraph(successor.getSuccessor(),
+ [](Region *nextRegion, ArrayRef<bool> visited) {
+ // Interrupt traversal if the region was already
+ // visited.
+ return visited[nextRegion->getRegionNumber()];
+ }))
+ return true;
return false;
}
Region *mlir::getEnclosingRepetitiveRegion(Operation *op) {
- LDBG() << "Finding enclosing repetitive region for operation "
- << op->getName();
-
while (Region *region = op->getParentRegion()) {
- LDBG() << "Checking region #" << region->getRegionNumber()
- << " in operation " << region->getParentOp()->getName();
-
op = region->getParentOp();
- if (auto branchOp = dyn_cast<RegionBranchOpInterface>(op)) {
- LDBG()
- << "Found RegionBranchOpInterface, checking if region is repetitive";
- if (branchOp.isRepetitiveRegion(region->getRegionNumber())) {
- LDBG() << "Found repetitive region #" << region->getRegionNumber();
+ if (auto branchOp = dyn_cast<RegionBranchOpInterface>(op))
+ if (branchOp.isRepetitiveRegion(region->getRegionNumber()))
return region;
- }
- } else {
- LDBG() << "Parent operation does not implement RegionBranchOpInterface";
- }
}
-
- LDBG() << "No enclosing repetitive region found";
return nullptr;
}
Region *mlir::getEnclosingRepetitiveRegion(Value value) {
- LDBG() << "Finding enclosing repetitive region for value";
-
Region *region = value.getParentRegion();
while (region) {
- LDBG() << "Checking region #" << region->getRegionNumber()
- << " in operation " << region->getParentOp()->getName();
-
Operation *op = region->getParentOp();
- if (auto branchOp = dyn_cast<RegionBranchOpInterface>(op)) {
- LDBG()
- << "Found RegionBranchOpInterface, checking if region is repetitive";
- if (branchOp.isRepetitiveRegion(region->getRegionNumber())) {
- LDBG() << "Found repetitive region #" << region->getRegionNumber();
+ if (auto branchOp = dyn_cast<RegionBranchOpInterface>(op))
+ if (branchOp.isRepetitiveRegion(region->getRegionNumber()))
return region;
- }
- } else {
- LDBG() << "Parent operation does not implement RegionBranchOpInterface";
- }
region = op->getParentRegion();
}
-
- LDBG() << "No enclosing repetitive region found for value";
return nullptr;
}
diff --git a/mlir/lib/Transforms/RemoveDeadValues.cpp b/mlir/lib/Transforms/RemoveDeadValues.cpp
index 41f3f9d76a3b1..e0c65b0e09774 100644
--- a/mlir/lib/Transforms/RemoveDeadValues.cpp
+++ b/mlir/lib/Transforms/RemoveDeadValues.cpp
@@ -432,7 +432,8 @@ static void processRegionBranchOp(RegionBranchOpInterface regionBranchOp,
// Return the successors of `region` if the latter is not null. Else return
// the successors of `regionBranchOp`.
- auto getSuccessors = [&](RegionBranchPoint point) {
+ auto getSuccessors = [&](Region *region = nullptr) {
+ auto point = region ? region : RegionBranchPoint::parent();
SmallVector<RegionSuccessor> successors;
regionBranchOp.getSuccessorRegions(point, successors);
return successors;
@@ -455,8 +456,7 @@ static void processRegionBranchOp(RegionBranchOpInterface regionBranchOp,
// `nonForwardedOperands`.
auto markNonForwardedOperands = [&](BitVector &nonForwardedOperands) {
nonForwardedOperands.resize(regionBranchOp->getNumOperands(), true);
- for (const RegionSuccessor &successor :
- getSuccessors(RegionBranchPoint::parent())) {
+ for (const RegionSuccessor &successor : getSuccessors()) {
for (OpOperand *opOperand : getForwardedOpOperands(successor))
nonForwardedOperands.reset(opOperand->getOperandNumber());
}
@@ -469,13 +469,10 @@ static void processRegionBranchOp(RegionBranchOpInterface regionBranchOp,
for (Region &region : regionBranchOp->getRegions()) {
if (region.empty())
continue;
- // TODO: this isn't correct in face of multiple terminators.
Operation *terminator = region.front().getTerminator();
nonForwardedRets[terminator] =
BitVector(terminator->getNumOperands(), true);
- for (const RegionSuccessor &successor :
- getSuccessors(RegionBranchPoint(
- cast<RegionBranchTerminatorOpInterface>(terminator)))) {
+ for (const RegionSuccessor &successor : getSuccessors(&region)) {
for (OpOperand *opOperand :
getForwardedOpOperands(successor, terminator))
nonForwardedRets[terminator].reset(opOperand->getOperandNumber());
@@ -492,13 +489,8 @@ static void processRegionBranchOp(RegionBranchOpInterface regionBranchOp,
DenseMap<Region *, BitVector> &argsToKeep, Region *region = nullptr) {
Operation *terminator =
region ? region->front().getTerminator() : nullptr;
- RegionBranchPoint point =
- terminator
- ? RegionBranchPoint(
- cast<RegionBranchTerminatorOpInterface>(terminator))
- : RegionBranchPoint::parent();
- for (const RegionSuccessor &successor : getSuccessors(point)) {
+ for (const RegionSuccessor &successor : getSuccessors(region)) {
Region *successorRegion = successor.getSuccessor();
for (auto [opOperand, input] :
llvm::zip(getForwardedOpOperands(successor, terminator),
@@ -525,8 +517,7 @@ static void processRegionBranchOp(RegionBranchOpInterface regionBranchOp,
resultsOrArgsToKeepChanged = false;
// Recompute `resultsToKeep` and `argsToKeep` based on `operandsToKeep`.
- for (const RegionSuccessor &successor :
- getSuccessors(RegionBranchPoint::parent())) {
+ for (const RegionSuccessor &successor : getSuccessors()) {
Region *successorRegion = successor.getSuccessor();
for (auto [opOperand, input] :
llvm::zip(getForwardedOpOperands(successor),
@@ -560,9 +551,7 @@ static void processRegionBranchOp(RegionBranchOpInterface regionBranchOp,
if (region.empty())
continue;
Operation *terminator = region.front().getTerminator();
- for (const RegionSuccessor &successor :
- getSuccessors(RegionBranchPoint(
- cast<RegionBranchTerminatorOpInterface>(terminator)))) {
+ for (const RegionSuccessor &successor : getSuccessors(&region)) {
Region *successorRegion = successor.getSuccessor();
for (auto [opOperand, input] :
llvm::zip(getForwardedOpOperands(successor, terminator),
diff --git a/mlir/test/Dialect/SCF/invalid.mlir b/mlir/test/Dialect/SCF/invalid.mlir
index 3f481ad5dbba7..37fc86b18e7f0 100644
--- a/mlir/test/Dialect/SCF/invalid.mlir
+++ b/mlir/test/Dialect/SCF/invalid.mlir
@@ -373,7 +373,7 @@ func.func @reduceReturn_not_inside_reduce(%arg0 : f32) {
func.func @std_if_incorrect_yield(%arg0: i1, %arg1: f32)
{
- // expected-error@+1 {{region control flow edge from Operation scf.yield to parent results: source has 1 operands, but target successor <to parent> needs 2}}
+ // expected-error@+1 {{region control flow edge from Region #0 to parent results: source has 1 operands, but target successor needs 2}}
%x, %y = scf.if %arg0 -> (f32, f32) {
%0 = arith.addf %arg1, %arg1 : f32
scf.yield %0 : f32
@@ -544,7 +544,7 @@ func.func @while_invalid_terminator() {
func.func @while_cross_region_type_mismatch() {
%true = arith.constant true
- // expected-error@+1 {{region control flow edge from Operation scf.condition to Region #1: source has 0 operands, but target successor <to region #1 with 1 inputs> needs 1}}
+ // expected-error@+1 {{'scf.while' op region control flow edge from Region #0 to Region #1: source has 0 operands, but target successor needs 1}}
scf.while : () -> () {
scf.condition(%true)
} do {
@@ -557,7 +557,7 @@ func.func @while_cross_region_type_mismatch() {
func.func @while_cross_region_type_mismatch() {
%true = arith.constant true
- // expected-error@+1 {{along control flow edge from Operation scf.condition to Region #1: source type #0 'i1' should match input type #0 'i32'}}
+ // expected-error@+1 {{'scf.while' op along control flow edge from Region #0 to Region #1: source type #0 'i1' should match input type #0 'i32'}}
%0 = scf.while : () -> (i1) {
scf.condition(%true) %true : i1
} do {
@@ -570,7 +570,7 @@ func.func @while_cross_region_type_mismatch() {
func.func @while_result_type_mismatch() {
%true = arith.constant true
- // expected-error@+1 {{region control flow edge from Operation scf.condition to parent results: source has 1 operands, but target successor <to parent> needs 0}}
+ // expected-error@+1 {{'scf.while' op region control flow edge from Region #0 to parent results: source has 1 operands, but target successor needs 0}}
scf.while : () -> () {
scf.condition(%true) %true : i1
} do {
diff --git a/mlir/test/lib/Analysis/DataFlow/TestDenseBackwardDataFlowAnalysis.cpp b/mlir/test/lib/Analysis/DataFlow/TestDenseBackwardDataFlowAnalysis.cpp
index 7a7a58384fbb8..eb0d9801e7d3f 100644
--- a/mlir/test/lib/Analysis/DataFlow/TestDenseBackwardDataFlowAnalysis.cpp
+++ b/mlir/test/lib/Analysis/DataFlow/TestDenseBackwardDataFlowAnalysis.cpp
@@ -66,7 +66,7 @@ class NextAccessAnalysis : public DenseBackwardDataFlowAnalysis<NextAccess> {
void visitRegionBranchControlFlowTransfer(RegionBranchOpInterface branch,
RegionBranchPoint regionFrom,
- RegionSuccessor regionTo,
+ RegionBranchPoint regionTo,
const NextAccess &after,
NextAccess *before) override;
@@ -240,7 +240,7 @@ void NextAccessAnalysis::visitCallControlFlowTransfer(
void NextAccessAnalysis::visitRegionBranchControlFlowTransfer(
RegionBranchOpInterface branch, RegionBranchPoint regionFrom,
- RegionSuccessor regionTo, const NextAccess &after, NextAccess *before) {
+ RegionBranchPoint regionTo, const NextAccess &after, NextAccess *before) {
LDBG() << "visitRegionBranchControlFlowTransfer: "
<< OpWithFlags(branch.getOperation(), OpPrintingFlags().skipRegions());
LDBG() << " regionFrom: " << (regionFrom.isParent() ? "parent" : "region");
diff --git a/mlir/test/lib/Dialect/Test/TestOpDefs.cpp b/mlir/test/lib/Dialect/Test/TestOpDefs.cpp
index 4d4ec02546bc7..b211e243f234c 100644
--- a/mlir/test/lib/Dialect/Test/TestOpDefs.cpp
+++ b/mlir/test/lib/Dialect/Test/TestOpDefs.cpp
@@ -633,9 +633,8 @@ ParseResult RegionIfOp::parse(OpAsmParser &parser, OperationState &result) {
parser.getCurrentLocation(), result.operands);
}
-OperandRange RegionIfOp::getEntrySuccessorOperands(RegionSuccessor successor) {
- assert(llvm::is_contained({&getThenRegion(), &getElseRegion()},
- successor.getSuccessor()) &&
+OperandRange RegionIfOp::getEntrySuccessorOperands(RegionBranchPoint point) {
+ assert(llvm::is_contained({&getThenRegion(), &getElseRegion()}, point) &&
"invalid region index");
return getOperands();
}
@@ -644,11 +643,10 @@ void RegionIfOp::getSuccessorRegions(
RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
// We always branch to the join region.
if (!point.isParent()) {
- if (point.getTerminatorPredecessorOrNull()->getParentRegion() !=
- &getJoinRegion())
+ if (point != getJoinRegion())
regions.push_back(RegionSuccessor(&getJoinRegion(), getJoinArgs()));
else
- regions.push_back(RegionSuccessor(getOperation(), getResults()));
+ regions.push_back(RegionSuccessor(getResults()));
return;
}
@@ -675,7 +673,7 @@ void AnyCondOp::getSuccessorRegions(RegionBranchPoint point,
if (point.isParent())
regions.emplace_back(&getRegion());
else
- regions.emplace_back(getOperation(), getResults());
+ regions.emplace_back(getResults());
}
void AnyCondOp::getRegionInvocationBounds(
@@ -1109,11 +1107,11 @@ void LoopBlockOp::getSuccessorRegions(
if (point.isParent())
return;
- regions.emplace_back(getOperation(), getOperation()->getResults());
+ regions.emplace_back((*this)->getResults());
}
-OperandRange LoopBlockOp::getEntrySuccessorOperands(RegionSuccessor successor) {
- assert(successor.getSuccessor() == &getBody());
+OperandRange LoopBlockOp::getEntrySuccessorOperands(RegionBranchPoint point) {
+ assert(point == getBody());
return MutableOperandRange(getInitMutable());
}
@@ -1122,8 +1120,8 @@ OperandRange LoopBlockOp::getEntrySuccessorOperands(RegionSuccessor successor) {
//===----------------------------------------------------------------------===//
MutableOperandRange
-LoopBlockTerminatorOp::getMutableSuccessorOperands(RegionSuccessor successor) {
- if (successor.isParent())
+LoopBlockTerminatorOp::getMutableSuccessorOperands(RegionBranchPoint point) {
+ if (point.isParent())
return getExitArgMutable();
return getNextIterArgMutable();
}
@@ -1215,7 +1213,7 @@ void TestStoreWithARegion::getSuccessorRegions(
if (point.isParent())
regions.emplace_back(&getBody(), getBody().front().getArguments());
else
- regions.emplace_back(getOperation(), getOperation()->getResults());
+ regions.emplace_back();
}
//===----------------------------------------------------------------------===//
@@ -1229,7 +1227,7 @@ void TestStoreWithALoopRegion::getSuccessorRegions(
// enter the body.
regions.emplace_back(
RegionSuccessor(&getBody(), getBody().front().getArguments()));
- regions.emplace_back(getOperation(), getOperation()->getResults());
+ regions.emplace_back();
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
index a3430ba49a291..05a33cf1afd94 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -2581,7 +2581,7 @@ def LoopBlockTerminatorOp : TEST_Op<"loop_block_term",
def TestNoTerminatorOp : TEST_Op<"switch_with_no_break", [
NoTerminator,
- DeclareOpInterfaceMethods<RegionBranchOpInterface>
+ DeclareOpInterfaceMethods<RegionBranchOpInterface, ["getSuccessorRegions"]>
]> {
let arguments = (ins Index:$arg, DenseI64ArrayAttr:$cases);
let regions = (region VariadicRegion<SizedRegion<1>>:$caseRegions);
diff --git a/mlir/unittests/Interfaces/ControlFlowInterfacesTest.cpp b/mlir/unittests/Interfaces/ControlFlowInterfacesTest.cpp
index 2e6950fca6be2..f1aae15393fd3 100644
--- a/mlir/unittests/Interfaces/ControlFlowInterfacesTest.cpp
+++ b/mlir/unittests/Interfaces/ControlFlowInterfacesTest.cpp
@@ -13,24 +13,17 @@
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/Parser/Parser.h"
-#include "llvm/Support/DebugLog.h"
#include <gtest/gtest.h>
using namespace mlir;
/// A dummy op that is also a terminator.
-struct DummyOp : public Op<DummyOp, OpTrait::IsTerminator, OpTrait::ZeroResults,
- OpTrait::ZeroSuccessors,
- RegionBranchTerminatorOpInterface::Trait> {
+struct DummyOp : public Op<DummyOp, OpTrait::IsTerminator> {
using Op::Op;
static ArrayRef<StringRef> getAttributeNames() { return {}; }
static StringRef getOperationName() { return "cftest.dummy_op"; }
-
- MutableOperandRange getMutableSuccessorOperands(RegionSuccessor point) {
- return MutableOperandRange(getOperation(), 0, 0);
- }
};
/// All regions of this op are mutually exclusive.
@@ -46,8 +39,6 @@ struct MutuallyExclusiveRegionsOp
// Regions have no successors.
void getSuccessorRegions(RegionBranchPoint point,
SmallVectorImpl<RegionSuccessor> &regions) {}
- using RegionBranchOpInterface::Trait<
- MutuallyExclusiveRegionsOp>::getSuccessorRegions;
};
/// All regions of this op call each other in a large circle.
@@ -62,18 +53,13 @@ struct LoopRegionsOp
void getSuccessorRegions(RegionBranchPoint point,
SmallVectorImpl<RegionSuccessor> &regions) {
- if (point.getTerminatorPredecessorOrNull()) {
- Region *region =
- point.getTerminatorPredecessorOrNull()->getParentRegion();
- if (region == &(*this)->getRegion(1))
+ if (Region *region = point.getRegionOrNull()) {
+ if (point == (*this)->getRegion(1))
// This region also branches back to the parent.
- regions.push_back(
- RegionSuccessor(getOperation()->getParentOp(),
- getOperation()->getParentOp()->getResults()));
+ regions.push_back(RegionSuccessor());
regions.push_back(RegionSuccessor(region));
}
}
- using RegionBranchOpInterface::Trait<LoopRegionsOp>::getSuccessorRegions;
};
/// Each region branches back it itself or the parent.
@@ -89,17 +75,11 @@ struct DoubleLoopRegionsOp
void getSuccessorRegions(RegionBranchPoint point,
SmallVectorImpl<RegionSuccessor> &regions) {
- if (point.getTerminatorPredecessorOrNull()) {
- Region *region =
- point.getTerminatorPredecessorOrNull()->getParentRegion();
- regions.push_back(
- RegionSuccessor(getOperation()->getParentOp(),
- getOperation()->getParentOp()->getResults()));
+ if (Region *region = point.getRegionOrNull()) {
+ regions.push_back(RegionSuccessor());
regions.push_back(RegionSuccessor(region));
}
}
- using RegionBranchOpInterface::Trait<
- DoubleLoopRegionsOp>::getSuccessorRegions;
};
/// Regions are executed sequentially.
@@ -113,15 +93,11 @@ struct SequentialRegionsOp
// Region 0 has Region 1 as a successor.
void getSuccessorRegions(RegionBranchPoint point,
SmallVectorImpl<RegionSuccessor> &regions) {
- if (point.getTerminatorPredecessorOrNull() &&
- point.getTerminatorPredecessorOrNull()->getParentRegion() ==
- &(*this)->getRegion(0)) {
+ if (point == (*this)->getRegion(0)) {
Operation *thisOp = this->getOperation();
regions.push_back(RegionSuccessor(&thisOp->getRegion(1)));
}
}
- using RegionBranchOpInterface::Trait<
- SequentialRegionsOp>::getSuccessorRegions;
};
/// A dialect putting all the above together.
More information about the Mlir-commits
mailing list