[Mlir-commits] [mlir] 0b24580 - [mlir][Interfaces][NFC] Add `RegionBranchOpInterface` helper for forwarded values (#173981)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Thu Jan 1 02:56:09 PST 2026
Author: Matthias Springer
Date: 2026-01-01T11:56:05+01:00
New Revision: 0b24580a26644c8a44abc6fba97609552ba41d42
URL: https://github.com/llvm/llvm-project/commit/0b24580a26644c8a44abc6fba97609552ba41d42
DIFF: https://github.com/llvm/llvm-project/commit/0b24580a26644c8a44abc6fba97609552ba41d42.diff
LOG: [mlir][Interfaces][NFC] Add `RegionBranchOpInterface` helper for forwarded values (#173981)
Add a helper function to compute a mapping of successor operands to
successor inputs. This mapping is computed in various places. Also add a
helper function to gather all region branch points.
This commit is in preparation for a bug fix / partial redesign of
`-remove-dead-values`. This commit also removes some duplicate code in
various places.
Added:
Modified:
mlir/include/mlir/Interfaces/ControlFlowInterfaces.h
mlir/include/mlir/Interfaces/ControlFlowInterfaces.td
mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp
mlir/lib/Dialect/Bufferization/Transforms/BufferViewFlowAnalysis.cpp
mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
mlir/lib/Interfaces/ControlFlowInterfaces.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Interfaces/ControlFlowInterfaces.h b/mlir/include/mlir/Interfaces/ControlFlowInterfaces.h
index bfc24c18429ed..566f4b8fadb5d 100644
--- a/mlir/include/mlir/Interfaces/ControlFlowInterfaces.h
+++ b/mlir/include/mlir/Interfaces/ControlFlowInterfaces.h
@@ -176,6 +176,19 @@ namespace detail {
LogicalResult verifyTypesAlongControlFlowEdges(Operation *op);
} // namespace detail
+/// A mapping from successor operands to successor inputs.
+///
+/// * A successor operand is an operand of a region branch op or region
+/// branch terminator that is forwarded to a successor input.
+/// * A successor input is a block argument of a region or a result of the
+/// region branch op that is populated by a successor operand.
+///
+/// The mapping is 1:N: each successor operand may be forwarded to multiple
+/// successor inputs, because control flow can dispatch to multiple possible
+/// successors. Operands that are not forwarded at all are not present in the
+/// mapping.
+using RegionBranchSuccessorMapping = DenseMap<OpOperand *, SmallVector<Value>>;
+
/// This class represents a successor of a region. A region successor can either
/// be another region, or the parent operation. If the successor is a region,
/// this class represents the destination region, as well as a set of arguments
diff --git a/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td b/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td
index 8760c8b8715f9..2e654ba04ffe5 100644
--- a/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td
+++ b/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td
@@ -251,30 +251,16 @@ def RegionBranchOpInterface : OpInterface<"RegionBranchOpInterface"> {
"::llvm::SmallVectorImpl<::mlir::RegionBranchPoint> &":$predecessors),
[{}],
/*defaultImplementation=*/[{
- ::llvm::SmallVector<::mlir::RegionSuccessor> successors;
- $_op.getSuccessorRegions(RegionBranchPoint::parent(),
- successors);
- if (llvm::any_of(successors, [&] (const RegionSuccessor & succ) {
+ auto op = cast<RegionBranchOpInterface>($_op.getOperation());
+ for (::mlir::RegionBranchPoint point : op.getAllRegionBranchPoints()) {
+ ::llvm::SmallVector<::mlir::RegionSuccessor> successors;
+ op.getSuccessorRegions(point, successors);
+ bool isPred = llvm::any_of(successors, [&] (const auto &succ) {
return succ.getSuccessor() == successor.getSuccessor() ||
- (succ.isParent() && successor.isParent());
- }))
- predecessors.push_back(RegionBranchPoint::parent());
- for (Region &region : $_op->getRegions()) {
- for (::mlir::Block &block : region) {
- if (block.empty())
- continue;
- if (auto terminator =
- dyn_cast<RegionBranchTerminatorOpInterface>(block.back())) {
- ::llvm::SmallVector<::mlir::RegionSuccessor> successors;
- $_op.getSuccessorRegions(RegionBranchPoint(terminator),
- successors);
- if (llvm::any_of(successors, [&] (const RegionSuccessor & succ) {
- return succ.getSuccessor() == successor.getSuccessor() ||
- (succ.isParent() && successor.isParent());
- }))
- predecessors.push_back(terminator);
- }
- }
+ (succ.isParent() && successor.isParent());
+ });
+ if (isPred)
+ predecessors.push_back(point);
}
}]>,
InterfaceMethod<[{
@@ -359,6 +345,19 @@ def RegionBranchOpInterface : OpInterface<"RegionBranchOpInterface"> {
/// of the respective terminator.
::mlir::OperandRange getSuccessorOperands(
::mlir::RegionBranchPoint src, ::mlir::RegionSuccessor dest);
+
+ /// Build a mapping from successor operands to successor inputs. Each
+ /// successor operand could be forwarded to multiple successor inputs.
+ /// Operands that are not forwarded are not added to the map. Unless a
+ /// specific region branch point is specified, this function takes into
+ /// account all possible region branch points.
+ void getSuccessorOperandInputMapping(
+ ::mlir::RegionBranchSuccessorMapping &mapping,
+ std::optional<::mlir::RegionBranchPoint> src = std::nullopt);
+
+ /// Return all possible region branch points: the region branch op itself
+ /// and all region branch terminators.
+ ::llvm::SmallVector<::mlir::RegionBranchPoint> getAllRegionBranchPoints();
}];
}
diff --git a/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp b/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp
index 8e63ae86753b4..64adb8cf00ba7 100644
--- a/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp
+++ b/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp
@@ -591,29 +591,24 @@ LogicalResult AbstractSparseBackwardDataFlowAnalysis::visitCallableOperation(
void AbstractSparseBackwardDataFlowAnalysis::visitRegionSuccessors(
RegionBranchOpInterface branch,
ArrayRef<AbstractSparseLattice *> operandLattices) {
- Operation *op = branch.getOperation();
- SmallVector<RegionSuccessor> successors;
- SmallVector<Attribute> operands(op->getNumOperands(), nullptr);
- branch.getEntrySuccessorRegions(operands, successors);
-
- // All operands not forwarded to any successor. This set can be non-contiguous
- // in the presence of multiple successors.
- BitVector unaccounted(op->getNumOperands(), true);
-
- for (RegionSuccessor &successor : successors) {
- OperandRange operands = branch.getEntrySuccessorOperands(successor);
- MutableArrayRef<OpOperand> opoperands = operandsToOpOperands(operands);
- ValueRange inputs = successor.getSuccessorInputs();
- for (auto [operand, input] : llvm::zip(opoperands, inputs)) {
- meet(getLatticeElement(operand.get()),
- *getLatticeElementFor(getProgramPointAfter(op), input));
- unaccounted.reset(operand.getOperandNumber());
+ // Operands that are not forwarded to any successor. This set can be
+ // non-contiguous in the presence of multiple successors.
+ BitVector unaccounted(branch->getNumOperands(), true);
+
+ RegionBranchSuccessorMapping mapping;
+ branch.getSuccessorOperandInputMapping(mapping, RegionBranchPoint::parent());
+ for (const auto &[operand, inputs] : mapping) {
+ for (Value input : inputs) {
+ meet(getLatticeElement(operand->get()),
+ *getLatticeElementFor(getProgramPointAfter(branch), input));
+ unaccounted.reset(operand->getOperandNumber());
}
}
+
// All operands not forwarded to regions are typically parameters of the
// branch operation itself (for example the boolean for if/else).
for (int index : unaccounted.set_bits()) {
- visitBranchOperand(op->getOpOperand(index));
+ visitBranchOperand(branch->getOpOperand(index));
}
}
@@ -626,24 +621,21 @@ void AbstractSparseBackwardDataFlowAnalysis::
assert(terminator->getParentOp() == branch.getOperation() &&
"expected `branch` to be the parent op of `terminator`");
- SmallVector<Attribute> operandAttributes(terminator->getNumOperands(),
- nullptr);
- SmallVector<RegionSuccessor> successors;
- terminator.getSuccessorRegions(operandAttributes, successors);
- // All operands not forwarded to any successor. This set can be
+ // Operands that are not forwarded to any successor. This set can be
// non-contiguous in the presence of multiple successors.
BitVector unaccounted(terminator->getNumOperands(), true);
- for (const RegionSuccessor &successor : successors) {
- ValueRange inputs = successor.getSuccessorInputs();
- OperandRange operands = terminator.getSuccessorOperands(successor);
- MutableArrayRef<OpOperand> opOperands = operandsToOpOperands(operands);
- for (auto [opOperand, input] : llvm::zip(opOperands, inputs)) {
- meet(getLatticeElement(opOperand.get()),
+ RegionBranchSuccessorMapping mapping;
+ branch.getSuccessorOperandInputMapping(mapping,
+ RegionBranchPoint(terminator));
+ for (const auto &[operand, inputs] : mapping) {
+ for (Value input : inputs) {
+ meet(getLatticeElement(operand->get()),
*getLatticeElementFor(getProgramPointAfter(terminator), input));
- unaccounted.reset(const_cast<OpOperand &>(opOperand).getOperandNumber());
+ unaccounted.reset(operand->getOperandNumber());
}
}
+
// Visit operands of the branch op not forwarded to the next region.
// (Like e.g. the boolean of `scf.conditional`)
for (int index : unaccounted.set_bits()) {
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/BufferViewFlowAnalysis.cpp b/mlir/lib/Dialect/Bufferization/Transforms/BufferViewFlowAnalysis.cpp
index 0b2e080e52b75..ac94ef9d866fa 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/BufferViewFlowAnalysis.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/BufferViewFlowAnalysis.cpp
@@ -142,37 +142,12 @@ void BufferViewFlowAnalysis::build(Operation *op) {
}
if (auto regionInterface = dyn_cast<RegionBranchOpInterface>(op)) {
- // Query the RegionBranchOpInterface to find potential successor regions.
- // Extract all entry regions and wire all initial entry successor inputs.
- SmallVector<RegionSuccessor, 2> entrySuccessors;
- regionInterface.getSuccessorRegions(/*point=*/RegionBranchPoint::parent(),
- entrySuccessors);
- for (RegionSuccessor &entrySuccessor : entrySuccessors) {
- // Wire the entry region's successor arguments with the initial
- // successor inputs.
- registerDependencies(
- regionInterface.getEntrySuccessorOperands(entrySuccessor),
- entrySuccessor.getSuccessorInputs());
- }
-
- // Wire flow between regions and from region exits.
- for (Region &region : regionInterface->getRegions()) {
- // Iterate over all successor region entries that are reachable from the
- // current region.
- SmallVector<RegionSuccessor, 2> successorRegions;
- regionInterface.getSuccessorRegions(region, successorRegions);
- for (RegionSuccessor &successorRegion : successorRegions) {
- // Iterate over all immediate terminator operations and wire the
- // successor inputs with the successor operands of each terminator.
- for (Block &block : region)
- if (auto terminator = dyn_cast<RegionBranchTerminatorOpInterface>(
- block.getTerminator()))
- registerDependencies(
- terminator.getSuccessorOperands(successorRegion),
- successorRegion.getSuccessorInputs());
- }
- }
-
+ // Wire the successor operands with the successor inputs.
+ RegionBranchSuccessorMapping mapping;
+ regionInterface.getSuccessorOperandInputMapping(mapping);
+ for (const auto &[operand, inputs] : mapping)
+ for (Value input : inputs)
+ registerDependencies({operand->get()}, {input});
return WalkResult::advance();
}
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
index 7fc75e7294ea3..60fb13c7c2cd7 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
@@ -1248,20 +1248,15 @@ updateControlFlowOps(mlir::OpBuilder &builder,
mlir::RegionBranchTerminatorOpInterface terminator,
GetLayoutFnTy getLayoutOfValue) {
// Only process if the terminator is inside a region branch op.
- if (!mlir::isa<mlir::RegionBranchOpInterface>(terminator->getParentOp()))
+ auto branchOp = dyn_cast<RegionBranchOpInterface>(terminator->getParentOp());
+ if (!branchOp)
return success();
- llvm::SmallVector<mlir::RegionSuccessor> successors;
- llvm::SmallVector<mlir::Attribute> operands(terminator->getNumOperands(),
- nullptr);
- terminator.getSuccessorRegions(operands, successors);
-
- for (mlir::RegionSuccessor &successor : successors) {
- mlir::OperandRange successorOperands =
- terminator.getSuccessorOperands(successor);
- mlir::ValueRange successorInputs = successor.getSuccessorInputs();
- for (auto [successorOperand, successorInput] :
- llvm::zip(successorOperands, successorInputs)) {
+ RegionBranchSuccessorMapping mapping; // Operand -> successor inputs (1:N).
+ branchOp.getSuccessorOperandInputMapping(mapping,
+ RegionBranchPoint(terminator));
+ for (const auto &[successorOperand, successorInputs] : mapping) {
+ for (Value successorInput : successorInputs) {
Type inputType = successorInput.getType();
// We only need to operate on tensor descriptor or vector types.
if (!isa<xegpu::TensorDescType, VectorType>(inputType))
@@ -1269,13 +1264,13 @@ updateControlFlowOps(mlir::OpBuilder &builder,
xegpu::DistributeLayoutAttr successorInputLayout =
getLayoutOfValue(successorInput);
xegpu::DistributeLayoutAttr successorOperandLayout =
- getLayoutOfValue(successorOperand);
+ getLayoutOfValue(successorOperand->get());
// If either of the layouts is not assigned, we cannot proceed.
if (!successorOperandLayout) {
LLVM_DEBUG(DBGS() << "No layout assigned for forwarded operand in "
"branch terminator: "
- << successorOperand << "\n");
+ << successorOperand->get() << "\n");
return failure();
}
// We expect the layouts to match.
diff --git a/mlir/lib/Interfaces/ControlFlowInterfaces.cpp b/mlir/lib/Interfaces/ControlFlowInterfaces.cpp
index 51861d7751450..d393ddb8d8336 100644
--- a/mlir/lib/Interfaces/ControlFlowInterfaces.cpp
+++ b/mlir/lib/Interfaces/ControlFlowInterfaces.cpp
@@ -489,6 +489,59 @@ RegionBranchOpInterface::getSuccessorOperands(RegionBranchPoint src,
return terminator.getSuccessorOperands(dest);
}
+/// Reinterpret `operands` as the `OpOperand`s that back it, so that callers
+/// can key maps by `OpOperand *`.
+static MutableArrayRef<OpOperand> operandsToOpOperands(OperandRange &operands) {
+ return MutableArrayRef<OpOperand>(operands.getBase(), operands.size());
+}
+
+/// Add to `mapping` all successor operand -> successor input edges that
+/// originate at the region branch point `src` of `branchOp`.
+static void
+getSuccessorOperandInputMapping(RegionBranchOpInterface branchOp,
+ RegionBranchSuccessorMapping &mapping,
+ RegionBranchPoint src) {
+ SmallVector<RegionSuccessor> successors;
+ branchOp.getSuccessorRegions(src, successors);
+ for (RegionSuccessor dst : successors) {
+ OperandRange operands = branchOp.getSuccessorOperands(src, dst);
+ assert(operands.size() == dst.getSuccessorInputs().size() &&
+ "expected the same number of operands and inputs");
+ for (const auto &[operand, input] : llvm::zip_equal(
+ operandsToOpOperands(operands), dst.getSuccessorInputs()))
+ mapping[&operand].push_back(input);
+ }
+}
+
+void RegionBranchOpInterface::getSuccessorOperandInputMapping(
+ RegionBranchSuccessorMapping &mapping,
+ std::optional<RegionBranchPoint> src) {
+ if (src.has_value()) {
+ ::getSuccessorOperandInputMapping(*this, mapping, src.value());
+ } else {
+ // No region branch point specified: populate the mapping for all possible
+ // region branch points.
+ for (RegionBranchPoint branchPoint : getAllRegionBranchPoints())
+ ::getSuccessorOperandInputMapping(*this, mapping, branchPoint);
+ }
+}
+
+SmallVector<RegionBranchPoint>
+RegionBranchOpInterface::getAllRegionBranchPoints() {
+ SmallVector<RegionBranchPoint> branchPoints;
+ branchPoints.push_back(RegionBranchPoint::parent());
+ for (Region &region : getOperation()->getRegions())
+ for (Block &block : region) {
+ // Skip empty blocks: `Block::back()` must not be called on them.
+ if (block.empty())
+ continue;
+ if (auto terminator =
+ dyn_cast<RegionBranchTerminatorOpInterface>(block.back()))
+ branchPoints.push_back(RegionBranchPoint(terminator));
+ }
+ return branchPoints;
+}
+
Region *mlir::getEnclosingRepetitiveRegion(Operation *op) {
LDBG() << "Finding enclosing repetitive region for operation "
<< op->getName();
More information about the Mlir-commits
mailing list