[Mlir-commits] [mlir] [mlir][Interfaces][NFC] Add `RegionBranchOpInterface` helper for forwarded values (PR #173981)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Tue Dec 30 04:11:57 PST 2025
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-gpu
Author: Matthias Springer (matthias-springer)
<details>
<summary>Changes</summary>
Add a helper function that computes the mapping from successor operands to successor inputs; this mapping is currently recomputed by hand in several places. Also add a helper function that gathers all region branch points.
---
Full diff: https://github.com/llvm/llvm-project/pull/173981.diff
6 Files Affected:
- (modified) mlir/include/mlir/Interfaces/ControlFlowInterfaces.h (+13)
- (modified) mlir/include/mlir/Interfaces/ControlFlowInterfaces.td (+22-23)
- (modified) mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp (+22-30)
- (modified) mlir/lib/Dialect/Bufferization/Transforms/BufferViewFlowAnalysis.cpp (+6-31)
- (modified) mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp (+13-18)
- (modified) mlir/lib/Interfaces/ControlFlowInterfaces.cpp (+44)
``````````diff
diff --git a/mlir/include/mlir/Interfaces/ControlFlowInterfaces.h b/mlir/include/mlir/Interfaces/ControlFlowInterfaces.h
index bfc24c18429ed..566f4b8fadb5d 100644
--- a/mlir/include/mlir/Interfaces/ControlFlowInterfaces.h
+++ b/mlir/include/mlir/Interfaces/ControlFlowInterfaces.h
@@ -176,6 +176,19 @@ namespace detail {
LogicalResult verifyTypesAlongControlFlowEdges(Operation *op);
} // namespace detail
+/// A mapping from successor operands to successor inputs.
+///
+/// * A successor operand is an operand of a region branch op or of a region
+///   branch terminator that is forwarded to a successor input.
+/// * A successor input is a block argument of a region, or a result of the
+///   region branch op, that is populated by a successor operand.
+///
+/// The mapping is 1:N: each successor operand may be forwarded to multiple
+/// successor inputs, because control flow can dispatch to multiple possible
+/// successors. Operands that are not forwarded at all are not present in the
+/// mapping.
+using RegionBranchSuccessorMapping = DenseMap<OpOperand *, SmallVector<Value>>;
+
/// This class represents a successor of a region. A region successor can either
/// be another region, or the parent operation. If the successor is a region,
/// this class represents the destination region, as well as a set of arguments
diff --git a/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td b/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td
index 8760c8b8715f9..2e654ba04ffe5 100644
--- a/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td
+++ b/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td
@@ -251,30 +251,16 @@ def RegionBranchOpInterface : OpInterface<"RegionBranchOpInterface"> {
"::llvm::SmallVectorImpl<::mlir::RegionBranchPoint> &":$predecessors),
[{}],
/*defaultImplementation=*/[{
- ::llvm::SmallVector<::mlir::RegionSuccessor> successors;
- $_op.getSuccessorRegions(RegionBranchPoint::parent(),
- successors);
- if (llvm::any_of(successors, [&] (const RegionSuccessor & succ) {
+ auto op = cast<RegionBranchOpInterface>($_op.getOperation());
+ for (::mlir::RegionBranchPoint point : op.getAllRegionBranchPoints()) {
+ ::llvm::SmallVector<::mlir::RegionSuccessor> successors;
+ op.getSuccessorRegions(point, successors);
+ bool isPred = llvm::any_of(successors, [&] (const auto &succ) {
return succ.getSuccessor() == successor.getSuccessor() ||
- (succ.isParent() && successor.isParent());
- }))
- predecessors.push_back(RegionBranchPoint::parent());
- for (Region &region : $_op->getRegions()) {
- for (::mlir::Block &block : region) {
- if (block.empty())
- continue;
- if (auto terminator =
- dyn_cast<RegionBranchTerminatorOpInterface>(block.back())) {
- ::llvm::SmallVector<::mlir::RegionSuccessor> successors;
- $_op.getSuccessorRegions(RegionBranchPoint(terminator),
- successors);
- if (llvm::any_of(successors, [&] (const RegionSuccessor & succ) {
- return succ.getSuccessor() == successor.getSuccessor() ||
- (succ.isParent() && successor.isParent());
- }))
- predecessors.push_back(terminator);
- }
- }
+ (succ.isParent() && successor.isParent());
+ });
+ if (isPred)
+ predecessors.push_back(point);
}
}]>,
InterfaceMethod<[{
@@ -359,6 +345,19 @@ def RegionBranchOpInterface : OpInterface<"RegionBranchOpInterface"> {
/// of the respective terminator.
::mlir::OperandRange getSuccessorOperands(
::mlir::RegionBranchPoint src, ::mlir::RegionSuccessor dest);
+
+ /// Build a mapping from successor operands to successor inputs. Each
+ /// successor operand may be forwarded to multiple successor inputs.
+ /// Operands that are not forwarded are not added to the map. Entries are
+ /// appended to `mapping`; existing entries are kept. If `src` is set, only
+ /// the successors of that region branch point are considered; otherwise,
+ /// all possible region branch points are taken into account.
+ void getSuccessorOperandInputMapping(
+ ::mlir::RegionBranchSuccessorMapping &mapping,
+ std::optional<::mlir::RegionBranchPoint> src = std::nullopt);
+
+ /// Return all possible region branch points: the region branch op itself
+ /// (the parent point), followed by every region branch terminator found at
+ /// the end of a block in one of the op's regions.
+ ::llvm::SmallVector<::mlir::RegionBranchPoint> getAllRegionBranchPoints();
}];
}
diff --git a/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp b/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp
index 8e63ae86753b4..64adb8cf00ba7 100644
--- a/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp
+++ b/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp
@@ -591,29 +591,24 @@ LogicalResult AbstractSparseBackwardDataFlowAnalysis::visitCallableOperation(
void AbstractSparseBackwardDataFlowAnalysis::visitRegionSuccessors(
RegionBranchOpInterface branch,
ArrayRef<AbstractSparseLattice *> operandLattices) {
- Operation *op = branch.getOperation();
- SmallVector<RegionSuccessor> successors;
- SmallVector<Attribute> operands(op->getNumOperands(), nullptr);
- branch.getEntrySuccessorRegions(operands, successors);
-
- // All operands not forwarded to any successor. This set can be non-contiguous
- // in the presence of multiple successors.
- BitVector unaccounted(op->getNumOperands(), true);
-
- for (RegionSuccessor &successor : successors) {
- OperandRange operands = branch.getEntrySuccessorOperands(successor);
- MutableArrayRef<OpOperand> opoperands = operandsToOpOperands(operands);
- ValueRange inputs = successor.getSuccessorInputs();
- for (auto [operand, input] : llvm::zip(opoperands, inputs)) {
- meet(getLatticeElement(operand.get()),
- *getLatticeElementFor(getProgramPointAfter(op), input));
- unaccounted.reset(operand.getOperandNumber());
+ // Operands that are not forwarded to any successor. This set can be
+ // non-contiguous in the presence of multiple successors.
+ BitVector unaccounted(branch->getNumOperands(), true);
+
+ RegionBranchSuccessorMapping mapping;
+ branch.getSuccessorOperandInputMapping(mapping, RegionBranchPoint::parent());
+ for (const auto &[operand, inputs] : mapping) {
+ for (Value input : inputs) {
+ meet(getLatticeElement(operand->get()),
+ *getLatticeElementFor(getProgramPointAfter(branch), input));
+ unaccounted.reset(operand->getOperandNumber());
}
}
+
// All operands not forwarded to regions are typically parameters of the
// branch operation itself (for example the boolean for if/else).
for (int index : unaccounted.set_bits()) {
- visitBranchOperand(op->getOpOperand(index));
+ visitBranchOperand(branch->getOpOperand(index));
}
}
@@ -626,24 +621,21 @@ void AbstractSparseBackwardDataFlowAnalysis::
assert(terminator->getParentOp() == branch.getOperation() &&
"expected `branch` to be the parent op of `terminator`");
- SmallVector<Attribute> operandAttributes(terminator->getNumOperands(),
- nullptr);
- SmallVector<RegionSuccessor> successors;
- terminator.getSuccessorRegions(operandAttributes, successors);
- // All operands not forwarded to any successor. This set can be
+ // Operands that are not forwarded to any successor. This set can be
// non-contiguous in the presence of multiple successors.
BitVector unaccounted(terminator->getNumOperands(), true);
- for (const RegionSuccessor &successor : successors) {
- ValueRange inputs = successor.getSuccessorInputs();
- OperandRange operands = terminator.getSuccessorOperands(successor);
- MutableArrayRef<OpOperand> opOperands = operandsToOpOperands(operands);
- for (auto [opOperand, input] : llvm::zip(opOperands, inputs)) {
- meet(getLatticeElement(opOperand.get()),
+ RegionBranchSuccessorMapping mapping;
+ branch.getSuccessorOperandInputMapping(mapping,
+ RegionBranchPoint(terminator));
+ for (const auto &[operand, inputs] : mapping) {
+ for (Value input : inputs) {
+ meet(getLatticeElement(operand->get()),
*getLatticeElementFor(getProgramPointAfter(terminator), input));
- unaccounted.reset(const_cast<OpOperand &>(opOperand).getOperandNumber());
+ unaccounted.reset(operand->getOperandNumber());
}
}
+
// Visit operands of the branch op not forwarded to the next region.
// (Like e.g. the boolean of `scf.conditional`)
for (int index : unaccounted.set_bits()) {
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/BufferViewFlowAnalysis.cpp b/mlir/lib/Dialect/Bufferization/Transforms/BufferViewFlowAnalysis.cpp
index 0b2e080e52b75..ac94ef9d866fa 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/BufferViewFlowAnalysis.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/BufferViewFlowAnalysis.cpp
@@ -142,37 +142,12 @@ void BufferViewFlowAnalysis::build(Operation *op) {
}
if (auto regionInterface = dyn_cast<RegionBranchOpInterface>(op)) {
- // Query the RegionBranchOpInterface to find potential successor regions.
- // Extract all entry regions and wire all initial entry successor inputs.
- SmallVector<RegionSuccessor, 2> entrySuccessors;
- regionInterface.getSuccessorRegions(/*point=*/RegionBranchPoint::parent(),
- entrySuccessors);
- for (RegionSuccessor &entrySuccessor : entrySuccessors) {
- // Wire the entry region's successor arguments with the initial
- // successor inputs.
- registerDependencies(
- regionInterface.getEntrySuccessorOperands(entrySuccessor),
- entrySuccessor.getSuccessorInputs());
- }
-
- // Wire flow between regions and from region exits.
- for (Region &region : regionInterface->getRegions()) {
- // Iterate over all successor region entries that are reachable from the
- // current region.
- SmallVector<RegionSuccessor, 2> successorRegions;
- regionInterface.getSuccessorRegions(region, successorRegions);
- for (RegionSuccessor &successorRegion : successorRegions) {
- // Iterate over all immediate terminator operations and wire the
- // successor inputs with the successor operands of each terminator.
- for (Block &block : region)
- if (auto terminator = dyn_cast<RegionBranchTerminatorOpInterface>(
- block.getTerminator()))
- registerDependencies(
- terminator.getSuccessorOperands(successorRegion),
- successorRegion.getSuccessorInputs());
- }
- }
-
+ // Wire the successor operands with the successor inputs.
+ DenseMap<OpOperand *, SmallVector<Value>> mapping;
+ regionInterface.getSuccessorOperandInputMapping(mapping);
+ for (const auto &[operand, inputs] : mapping)
+ for (Value input : inputs)
+ registerDependencies({operand->get()}, {input});
return WalkResult::advance();
}
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
index 7fc75e7294ea3..6e46fd1570045 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
@@ -1248,34 +1248,29 @@ updateControlFlowOps(mlir::OpBuilder &builder,
mlir::RegionBranchTerminatorOpInterface terminator,
GetLayoutFnTy getLayoutOfValue) {
// Only process if the terminator is inside a region branch op.
- if (!mlir::isa<mlir::RegionBranchOpInterface>(terminator->getParentOp()))
+ auto branchOp = dyn_cast<RegionBranchOpInterface>(terminator->getParentOp());
+ if (!branchOp)
return success();
- llvm::SmallVector<mlir::RegionSuccessor> successors;
- llvm::SmallVector<mlir::Attribute> operands(terminator->getNumOperands(),
- nullptr);
- terminator.getSuccessorRegions(operands, successors);
-
- for (mlir::RegionSuccessor &successor : successors) {
- mlir::OperandRange successorOperands =
- terminator.getSuccessorOperands(successor);
- mlir::ValueRange successorInputs = successor.getSuccessorInputs();
- for (auto [successorOperand, successorInput] :
- llvm::zip(successorOperands, successorInputs)) {
- Type inputType = successorInput.getType();
+ RegionBranchSuccessorMapping mapping;
+ branchOp.getSuccessorOperandInputMapping(mapping,
+ RegionBranchPoint(terminator));
+ for (const auto &[operand, inputs] : mapping) {
+ for (Value input : inputs) {
+ Type inputType = input.getType();
// We only need to operate on tensor descriptor or vector types.
if (!isa<xegpu::TensorDescType, VectorType>(inputType))
continue;
xegpu::DistributeLayoutAttr successorInputLayout =
- getLayoutOfValue(successorInput);
+ getLayoutOfValue(input);
xegpu::DistributeLayoutAttr successorOperandLayout =
- getLayoutOfValue(successorOperand);
+ getLayoutOfValue(operand->get());
// If either of the layouts is not assigned, we cannot proceed.
if (!successorOperandLayout) {
LLVM_DEBUG(DBGS() << "No layout assigned for forwarded operand in "
"branch terminator: "
- << successorOperand << "\n");
+ << operand->get() << "\n");
return failure();
}
// We expect the layouts to match.
@@ -1292,12 +1287,12 @@ updateControlFlowOps(mlir::OpBuilder &builder,
auto newTdescTy = xegpu::TensorDescType::get(
tdescTy.getContext(), tdescTy.getShape(), tdescTy.getElementType(),
tdescTy.getEncoding(), successorOperandLayout);
- successorInput.setType(newTdescTy);
+ input.setType(newTdescTy);
continue;
}
// If the type is a vector type and this region argument is an OpResult,
// set the layout attribute on the OpResult.
- if (auto result = dyn_cast<OpResult>(successorInput))
+ if (auto result = dyn_cast<OpResult>(input))
xegpu::setDistributeLayoutAttr(result, successorOperandLayout);
}
}
diff --git a/mlir/lib/Interfaces/ControlFlowInterfaces.cpp b/mlir/lib/Interfaces/ControlFlowInterfaces.cpp
index 51861d7751450..d393ddb8d8336 100644
--- a/mlir/lib/Interfaces/ControlFlowInterfaces.cpp
+++ b/mlir/lib/Interfaces/ControlFlowInterfaces.cpp
@@ -489,6 +489,50 @@ RegionBranchOpInterface::getSuccessorOperands(RegionBranchPoint src,
return terminator.getSuccessorOperands(dest);
}
+/// View a range of operands as the `OpOperand` objects that back it, so that
+/// callers can take the address of individual operands.
+static MutableArrayRef<OpOperand> operandsToOpOperands(OperandRange &operands) {
+  OpOperand *first = operands.getBase();
+  return MutableArrayRef<OpOperand>(first, operands.size());
+}
+
+/// For the region branch point `src`, record which successor input(s) each
+/// forwarded successor operand flows into. Entries are appended to `mapping`.
+static void
+getSuccessorOperandInputMapping(RegionBranchOpInterface branchOp,
+                                RegionBranchSuccessorMapping &mapping,
+                                RegionBranchPoint src) {
+  SmallVector<RegionSuccessor> dests;
+  branchOp.getSuccessorRegions(src, dests);
+  for (RegionSuccessor dest : dests) {
+    OperandRange forwarded = branchOp.getSuccessorOperands(src, dest);
+    ValueRange inputs = dest.getSuccessorInputs();
+    assert(forwarded.size() == inputs.size() &&
+           "expected the same number of operands and inputs");
+    MutableArrayRef<OpOperand> opOperands = operandsToOpOperands(forwarded);
+    for (auto [opOperand, input] : llvm::zip_equal(opOperands, inputs))
+      mapping[&opOperand].push_back(input);
+  }
+}
+
+void RegionBranchOpInterface::getSuccessorOperandInputMapping(
+    RegionBranchSuccessorMapping &mapping,
+    std::optional<RegionBranchPoint> src) {
+  // A specific region branch point was requested: map only its successors.
+  if (src) {
+    ::getSuccessorOperandInputMapping(*this, mapping, *src);
+    return;
+  }
+  // Otherwise accumulate the mapping over every possible region branch point.
+  SmallVector<RegionBranchPoint> points = getAllRegionBranchPoints();
+  for (const RegionBranchPoint &point : points)
+    ::getSuccessorOperandInputMapping(*this, mapping, point);
+}
+
+/// Return every region branch point of this op: the parent point first,
+/// followed by each block-ending `RegionBranchTerminatorOpInterface` op in
+/// the op's regions, in region/block order.
+SmallVector<RegionBranchPoint>
+RegionBranchOpInterface::getAllRegionBranchPoints() {
+  SmallVector<RegionBranchPoint> branchPoints;
+  branchPoints.push_back(RegionBranchPoint::parent());
+  for (Region &region : getOperation()->getRegions()) {
+    for (Block &block : region) {
+      // Skip empty blocks: `Block::back()` requires at least one operation.
+      if (block.empty())
+        continue;
+      if (auto terminator =
+              dyn_cast<RegionBranchTerminatorOpInterface>(block.back()))
+        branchPoints.push_back(RegionBranchPoint(terminator));
+    }
+  }
+  return branchPoints;
+}
+
Region *mlir::getEnclosingRepetitiveRegion(Operation *op) {
LDBG() << "Finding enclosing repetitive region for operation "
<< op->getName();
``````````
</details>
https://github.com/llvm/llvm-project/pull/173981
More information about the Mlir-commits
mailing list