[Mlir-commits] [mlir] [mlir][ControlFlow] Improve time complexity of RegionBranchOpInterface canonicalization patterns (PR #186114)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Thu Mar 12 06:00:41 PDT 2026
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir
Author: Yang Bai (yangtetris)
<details>
<summary>Changes</summary>
Optimize two canonicalization patterns in `ControlFlowInterfaces.cpp`:
1. **RemoveDuplicateSuccessorInputUses**: Replace the O(n² * k) pairwise comparison of successor inputs with an O(n * k * max(log k, log n)) signature-based grouping using `std::map`, where _n_ is the number of successor inputs and _k_ is the number of predecessors per input. (Sorting each input's signature costs O(k log k); each `std::map` lookup performs O(log n) key comparisons of O(k) each.)
2. **MakeRegionBranchOpSuccessorInputsDead**: Add an early exit to `computeReachableValuesFromSuccessorInput` for callers that only need to know whether there is exactly one reachable value. With a limit of 1, the traversal stops as soon as a second reachable value is found, which avoids unnecessary work.
---
Full diff: https://github.com/llvm/llvm-project/pull/186114.diff
1 File Affected:
- (modified) mlir/lib/Interfaces/ControlFlowInterfaces.cpp (+73-51)
``````````diff
diff --git a/mlir/lib/Interfaces/ControlFlowInterfaces.cpp b/mlir/lib/Interfaces/ControlFlowInterfaces.cpp
index 2f95531455b2b..c696cd6785f1d 100644
--- a/mlir/lib/Interfaces/ControlFlowInterfaces.cpp
+++ b/mlir/lib/Interfaces/ControlFlowInterfaces.cpp
@@ -6,6 +6,7 @@
//
//===----------------------------------------------------------------------===//
+#include <map>
#include <utility>
#include "mlir/IR/BuiltinTypes.h"
@@ -630,6 +631,15 @@ static bool isDefinedBefore(Operation *regionBranchOp, Value a, Value b) {
/// Compute all non-successor-input values that a successor input could have
/// based on the given successor input to successor operand mapping.
///
+/// Starting with the given value, trace back all predecessor values (i.e.,
+/// preceding successor operands) and add them to the set of reachable values.
+/// If the successor operand is again a successor input, do not add it to the
+/// result set, but instead continue the traversal.
+///
+/// If `maxReachableValues` is non-zero, the traversal is aborted early as soon
+/// as the number of reachable values exceeds the limit. This is useful when
+/// the caller only cares whether there is exactly one reachable value.
+///
/// Example 1:
/// %r = scf.if ... {
/// scf.yield %a : ...
@@ -653,12 +663,9 @@ static bool isDefinedBefore(Operation *regionBranchOp, Value a, Value b) {
/// reachableValues(%arg0) = {%0, %1}
/// reachableValues(%r) = {%0, %1}
static llvm::SmallDenseSet<Value> computeReachableValuesFromSuccessorInput(
- Value value, const RegionBranchInverseSuccessorMapping &inputToOperands) {
+ Value value, const RegionBranchInverseSuccessorMapping &inputToOperands,
+ unsigned maxReachableValues = 0) {
assert(inputToOperands.contains(value) && "value must be a successor input");
- // Starting with the given value, trace back all predecessor values (i.e.,
- // preceding successor operands) and add them to the set of reachable values.
- // If the successor operand is again a successor input, do not add it to
- // result set, but instead continue the traversal.
llvm::SmallDenseSet<Value> reachableValues;
llvm::SmallDenseSet<Value> visited;
SmallVector<Value> worklist;
@@ -668,6 +675,10 @@ static llvm::SmallDenseSet<Value> computeReachableValuesFromSuccessorInput(
auto it = inputToOperands.find(next);
if (it == inputToOperands.end()) {
reachableValues.insert(next);
+ // Early exit: stop traversal if more reachable values than the caller
+ // cares about have been found.
+ if (maxReachableValues > 0 && reachableValues.size() > maxReachableValues)
+ return reachableValues;
continue;
}
for (OpOperand *operand : it->second)
@@ -729,7 +740,8 @@ struct MakeRegionBranchOpSuccessorInputsDead : public RewritePattern {
// Nothing to do for successor inputs that may have multiple reachable
// values.
llvm::SmallDenseSet<Value> reachableValues =
- computeReachableValuesFromSuccessorInput(value, inputToOperands);
+ computeReachableValuesFromSuccessorInput(value, inputToOperands,
+ /*maxReachableValues=*/1);
if (reachableValues.size() != 1)
continue;
assert(*reachableValues.begin() != value &&
@@ -930,20 +942,6 @@ struct RemoveDeadRegionBranchOpSuccessorInputs : public RewritePattern {
}
};
-/// Return "true" if the two values are owned by the same operation or block.
-static bool haveSameOwner(Value a, Value b) {
- void *aOwner, *bOwner;
- if (auto arg = dyn_cast<BlockArgument>(a))
- aOwner = arg.getOwner();
- else
- aOwner = a.getDefiningOp();
- if (auto arg = dyn_cast<BlockArgument>(b))
- bOwner = arg.getOwner();
- else
- bOwner = b.getDefiningOp();
- return aOwner == bOwner;
-}
-
/// Get the block argument or op result number of the given value.
static unsigned getArgOrResultNumber(Value value) {
if (auto opResult = llvm::dyn_cast<OpResult>(value))
@@ -1006,39 +1004,63 @@ struct RemoveDuplicateSuccessorInputUses : public RewritePattern {
return getArgOrResultNumber(a) < getArgOrResultNumber(b);
});
- // Check every distinct pair of successor inputs for duplicates. Replace
- // `input2` with `input1` if they are duplicates.
+ // Group inputs by their operand "signature" to find duplicates. Two
+ // successor inputs are duplicates if each predecessor (region branch point)
+ // forwards the same value for both. Let n = number of successor inputs and
+ // k = number of predecessors per input. Instead of comparing every pair of
+ // inputs (O(n² * k)), we build a signature for each input and group them
+ // via a std::map.
+ //
+ // A signature is a sorted list of (predecessor, forwarded value) pairs.
+ // Within each group, all but the first (canonical) input are replaced with
+ // the canonical one.
+ using SigEntry = std::pair<Operation *, Value>;
+ using Signature = SmallVector<SigEntry>;
+ auto sigEntryLess = [](const SigEntry &a, const SigEntry &b) {
+ if (a.first != b.first)
+ return a.first < b.first;
+ return a.second.getAsOpaquePointer() < b.second.getAsOpaquePointer();
+ };
+ // The map key is (signature, owner). Two inputs are duplicates only if they
+ // have the same signature AND the same owner (block or defining op). This
+ // ensures we track one canonical per owner group.
+ using MapKey = std::pair<Signature, void *>;
+ auto mapKeyLess = [&](const MapKey &a, const MapKey &b) {
+ if (a.second != b.second)
+ return a.second < b.second;
+ return std::lexicographical_compare(a.first.begin(), a.first.end(),
+ b.first.begin(), b.first.end(),
+ sigEntryLess);
+ };
+ std::map<MapKey, Value, decltype(mapKeyLess)> signatureToCanonical(
+ mapKeyLess);
bool changed = false;
- unsigned numInputs = inputs.size();
- for (auto i : llvm::seq<unsigned>(0, numInputs)) {
- Value input1 = inputs[i];
- for (auto j : llvm::seq<unsigned>(i + 1, numInputs)) {
- Value input2 = inputs[j];
- // Nothing to do if input2 is already dead.
- if (input2.use_empty())
- continue;
- // Replace only values that belong to the same block / operation.
- // This implies that the two values are either both block arguments or
- // both op results.
- if (!haveSameOwner(input1, input2))
+ // Total complexity: O(n * k * max(log k, log n)). For each input, sorting
+ // the signature costs O(k log k) and the std::map lookup costs O(k log n).
+ for (Value input : inputs) {
+ // Gather the predecessor value for each predecessor (region branch
+ // point) and sort them to form this input's signature.
+ Signature sig;
+ for (OpOperand *operand : inputsToOperands[input])
+ sig.emplace_back(operand->getOwner(), operand->get());
+ llvm::sort(sig, sigEntryLess);
+
+ // Determine the owner (block for block args, defining op for results).
+ void *owner;
+ if (auto arg = dyn_cast<BlockArgument>(input))
+ owner = arg.getOwner();
+ else
+ owner = input.getDefiningOp();
+
+ auto [it, inserted] = signatureToCanonical.try_emplace(
+ MapKey{std::move(sig), owner}, input);
+ if (!inserted) {
+ Value canonical = it->second;
+ // Nothing to do if input is already dead.
+ if (input.use_empty())
continue;
-
- // Gather the predecessor value for each predecessor (region branch
- // point). The two inputs are duplicates if each predecessor forwards
- // the same value.
- llvm::SmallDenseMap<Operation *, Value> operands1, operands2;
- for (OpOperand *operand : inputsToOperands[input1]) {
- assert(!operands1.contains(operand->getOwner()));
- operands1[operand->getOwner()] = operand->get();
- }
- for (OpOperand *operand : inputsToOperands[input2]) {
- assert(!operands2.contains(operand->getOwner()));
- operands2[operand->getOwner()] = operand->get();
- }
- if (operands1 == operands2) {
- rewriter.replaceAllUsesWith(input2, input1);
- changed = true;
- }
+ rewriter.replaceAllUsesWith(input, canonical);
+ changed = true;
}
}
return success(changed);
``````````
</details>
https://github.com/llvm/llvm-project/pull/186114
More information about the Mlir-commits mailing list