[Mlir-commits] [mlir] [mlir][ControlFlow] Improve time complexity of RegionBranchOpInterface canonicalization patterns (PR #186114)
Matthias Springer
llvmlistbot at llvm.org
Fri Mar 13 03:01:55 PDT 2026
================
@@ -1006,39 +993,63 @@ struct RemoveDuplicateSuccessorInputUses : public RewritePattern {
return getArgOrResultNumber(a) < getArgOrResultNumber(b);
});
- // Check every distinct pair of successor inputs for duplicates. Replace
- // `input2` with `input1` if they are duplicates.
+ // Group inputs by their operand "signature" to find duplicates. Two
+ // successor inputs are duplicates if each predecessor (region branch point)
+ // forwards the same value for both. Let n = number of successor inputs and
+ // k = number of predecessors per input. Instead of comparing every pair of
+ // inputs (O(n² * k)), we build a signature for each input and group them
+ // via a std::map.
+ //
+ // A signature is a sorted list of (predecessor, forwarded value) pairs.
+ // Within each group, all but the first (canonical) input are replaced with
+ // the canonical one.
+ using SigEntry = std::pair<Operation *, Value>;
+ using Signature = SmallVector<SigEntry>;
+ auto sigEntryLess = [](const SigEntry &a, const SigEntry &b) {
+ if (a.first != b.first)
+ return a.first < b.first;
+ return a.second.getAsOpaquePointer() < b.second.getAsOpaquePointer();
+ };
+ // The map key is (signature, owner). Two inputs are duplicates only if they
+ // have the same signature AND the same owner (block or defining op). This
+ // ensures we track one canonical per owner group.
+ using MapKey = std::pair<Signature, void *>;
+ auto mapKeyLess = [&](const MapKey &a, const MapKey &b) {
+ if (a.second != b.second)
+ return a.second < b.second;
+ return std::lexicographical_compare(a.first.begin(), a.first.end(),
+ b.first.begin(), b.first.end(),
+ sigEntryLess);
+ };
+ std::map<MapKey, Value, decltype(mapKeyLess)> signatureToCanonical(
+ mapKeyLess);
bool changed = false;
- unsigned numInputs = inputs.size();
- for (auto i : llvm::seq<unsigned>(0, numInputs)) {
- Value input1 = inputs[i];
- for (auto j : llvm::seq<unsigned>(i + 1, numInputs)) {
- Value input2 = inputs[j];
- // Nothing to do if input2 is already dead.
- if (input2.use_empty())
+ // Total complexity: O(n * k * max(log k, log n)). For each input, sorting
+ // the signature costs O(k log k) and the std::map lookup costs O(k log n).
+ for (Value input : inputs) {
+ // Gather the predecessor value for each predecessor (region branch
+ // point) and sort them to form this input's signature.
+ Signature sig;
+ for (OpOperand *operand : inputsToOperands[input])
+ sig.emplace_back(operand->getOwner(), operand->get());
+ llvm::sort(sig, sigEntryLess);
+
+ // Determine the owner (block for block args, defining op for results).
----------------
matthias-springer wrote:
This could be moved into a new helper function:
```
static void *getOwnerOfValue(Value v) { ... }
```
We have this in a few places and I'd like to add a new API to `class Value` at some point.
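
For readers without the full patch at hand, the grouping approach described in the diff comments can be sketched outside of MLIR. This is a minimal standalone illustration, not the code from the PR: `ValueId`, `PredecessorId`, `OwnerId`, `SuccessorInput`, `findDuplicateInputs`, and `exampleUsage` are hypothetical stand-ins (plain integers replace `Value`, `Operation *`, and the owner pointer), and the default `std::pair`/`std::vector` ordering replaces the custom comparators.

```cpp
#include <algorithm>
#include <cassert>
#include <map>
#include <utility>
#include <vector>

// Stand-in types (assumptions for illustration, not MLIR classes).
using ValueId = int;       // plays the role of mlir::Value
using PredecessorId = int; // plays the role of the forwarding Operation *
using OwnerId = int;       // plays the role of the owning block / defining op

using SigEntry = std::pair<PredecessorId, ValueId>;
using Signature = std::vector<SigEntry>;

struct SuccessorInput {
  ValueId id;
  OwnerId owner;
  Signature forwarded; // unsorted (predecessor, forwarded value) pairs
};

// Returns (duplicate, canonical) pairs. The first input seen with a given
// (owner, signature) key becomes the canonical one for that group.
std::vector<std::pair<ValueId, ValueId>>
findDuplicateInputs(std::vector<SuccessorInput> inputs) {
  std::map<std::pair<OwnerId, Signature>, ValueId> signatureToCanonical;
  std::vector<std::pair<ValueId, ValueId>> replacements;
  for (SuccessorInput &input : inputs) {
    // Sorting makes the signature independent of predecessor visit order.
    std::sort(input.forwarded.begin(), input.forwarded.end());
    auto key = std::make_pair(input.owner, input.forwarded);
    auto [it, inserted] = signatureToCanonical.try_emplace(key, input.id);
    if (!inserted)
      replacements.emplace_back(input.id, it->second);
  }
  return replacements;
}

// Small usage example: inputs 0 and 1 share owner and signature, so input 1
// is reported as a duplicate of input 0. Input 2 has a different owner.
inline void exampleUsage() {
  std::vector<SuccessorInput> inputs = {
      {0, 100, {{1, 7}, {2, 8}}},
      {1, 100, {{2, 8}, {1, 7}}},
      {2, 200, {{1, 7}, {2, 8}}},
  };
  auto replacements = findDuplicateInputs(inputs);
  assert(replacements.size() == 1);
  assert(replacements[0] == std::make_pair(1, 0));
}
```

In this sketch each input costs one sort plus one ordered-map lookup, which matches the O(n * k * max(log k, log n)) bound stated in the patch comment. Which input ends up canonical depends only on iteration order, which is why the patch sorts `inputs` by argument or result number first.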
https://github.com/llvm/llvm-project/pull/186114