[Mlir-commits] [mlir] [mlir][ControlFlow] Improve time complexity of RegionBranchOpInterface canonicalization patterns (PR #186114)
Yang Bai
llvmlistbot at llvm.org
Fri Mar 13 05:35:56 PDT 2026
https://github.com/yangtetris updated https://github.com/llvm/llvm-project/pull/186114
From c1afb0abecae5b7d0e42bcd89c89e4213a364f43 Mon Sep 17 00:00:00 2001
From: Yang Bai <yangb at nvidia.com>
Date: Thu, 12 Mar 2026 03:51:43 -0700
Subject: [PATCH 1/2] [mlir] Improve time complexity of RemoveDuplicateSuccessorInputUses
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
Replace O(n² * k) pairwise comparison of successor inputs with
O(n * k * max(log k, log n)) signature-based grouping using std::map,
where n is the number of successor inputs and k is the number of
predecessors per input.
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply at anthropic.com>
---
mlir/lib/Interfaces/ControlFlowInterfaces.cpp | 101 ++++++++++--------
1 file changed, 56 insertions(+), 45 deletions(-)
diff --git a/mlir/lib/Interfaces/ControlFlowInterfaces.cpp b/mlir/lib/Interfaces/ControlFlowInterfaces.cpp
index 2f95531455b2b..1a0236122cf4d 100644
--- a/mlir/lib/Interfaces/ControlFlowInterfaces.cpp
+++ b/mlir/lib/Interfaces/ControlFlowInterfaces.cpp
@@ -6,6 +6,7 @@
//
//===----------------------------------------------------------------------===//
+#include <map>
#include <utility>
#include "mlir/IR/BuiltinTypes.h"
@@ -930,20 +931,6 @@ struct RemoveDeadRegionBranchOpSuccessorInputs : public RewritePattern {
}
};
-/// Return "true" if the two values are owned by the same operation or block.
-static bool haveSameOwner(Value a, Value b) {
- void *aOwner, *bOwner;
- if (auto arg = dyn_cast<BlockArgument>(a))
- aOwner = arg.getOwner();
- else
- aOwner = a.getDefiningOp();
- if (auto arg = dyn_cast<BlockArgument>(b))
- bOwner = arg.getOwner();
- else
- bOwner = b.getDefiningOp();
- return aOwner == bOwner;
-}
-
/// Get the block argument or op result number of the given value.
static unsigned getArgOrResultNumber(Value value) {
if (auto opResult = llvm::dyn_cast<OpResult>(value))
@@ -1006,39 +993,63 @@ struct RemoveDuplicateSuccessorInputUses : public RewritePattern {
return getArgOrResultNumber(a) < getArgOrResultNumber(b);
});
- // Check every distinct pair of successor inputs for duplicates. Replace
- // `input2` with `input1` if they are duplicates.
+ // Group inputs by their operand "signature" to find duplicates. Two
+ // successor inputs are duplicates if each predecessor (region branch point)
+ // forwards the same value for both. Let n = number of successor inputs and
+ // k = number of predecessors per input. Instead of comparing every pair of
+ // inputs (O(n² * k)), we build a signature for each input and group them
+ // via a std::map.
+ //
+ // A signature is a sorted list of (predecessor, forwarded value) pairs.
+ // Within each group, all but the first (canonical) input are replaced with
+ // the canonical one.
+ using SigEntry = std::pair<Operation *, Value>;
+ using Signature = SmallVector<SigEntry>;
+ auto sigEntryLess = [](const SigEntry &a, const SigEntry &b) {
+ if (a.first != b.first)
+ return a.first < b.first;
+ return a.second.getAsOpaquePointer() < b.second.getAsOpaquePointer();
+ };
+ // The map key is (signature, owner). Two inputs are duplicates only if they
+ // have the same signature AND the same owner (block or defining op). This
+ // ensures we track one canonical per owner group.
+ using MapKey = std::pair<Signature, void *>;
+ auto mapKeyLess = [&](const MapKey &a, const MapKey &b) {
+ if (a.second != b.second)
+ return a.second < b.second;
+ return std::lexicographical_compare(a.first.begin(), a.first.end(),
+ b.first.begin(), b.first.end(),
+ sigEntryLess);
+ };
+ std::map<MapKey, Value, decltype(mapKeyLess)> signatureToCanonical(
+ mapKeyLess);
bool changed = false;
- unsigned numInputs = inputs.size();
- for (auto i : llvm::seq<unsigned>(0, numInputs)) {
- Value input1 = inputs[i];
- for (auto j : llvm::seq<unsigned>(i + 1, numInputs)) {
- Value input2 = inputs[j];
- // Nothing to do if input2 is already dead.
- if (input2.use_empty())
+ // Total complexity: O(n * k * max(log k, log n)). For each input, sorting
+ // the signature costs O(k log k) and the std::map lookup costs O(k log n).
+ for (Value input : inputs) {
+ // Gather the predecessor value for each predecessor (region branch
+ // point) and sort them to form this input's signature.
+ Signature sig;
+ for (OpOperand *operand : inputsToOperands[input])
+ sig.emplace_back(operand->getOwner(), operand->get());
+ llvm::sort(sig, sigEntryLess);
+
+ // Determine the owner (block for block args, defining op for results).
+ void *owner;
+ if (auto arg = dyn_cast<BlockArgument>(input))
+ owner = arg.getOwner();
+ else
+ owner = input.getDefiningOp();
+
+ auto [it, inserted] = signatureToCanonical.try_emplace(
+ MapKey{std::move(sig), owner}, input);
+ if (!inserted) {
+ Value canonical = it->second;
+ // Nothing to do if input is already dead.
+ if (input.use_empty())
continue;
- // Replace only values that belong to the same block / operation.
- // This implies that the two values are either both block arguments or
- // both op results.
- if (!haveSameOwner(input1, input2))
- continue;
-
- // Gather the predecessor value for each predecessor (region branch
- // point). The two inputs are duplicates if each predecessor forwards
- // the same value.
- llvm::SmallDenseMap<Operation *, Value> operands1, operands2;
- for (OpOperand *operand : inputsToOperands[input1]) {
- assert(!operands1.contains(operand->getOwner()));
- operands1[operand->getOwner()] = operand->get();
- }
- for (OpOperand *operand : inputsToOperands[input2]) {
- assert(!operands2.contains(operand->getOwner()));
- operands2[operand->getOwner()] = operand->get();
- }
- if (operands1 == operands2) {
- rewriter.replaceAllUsesWith(input2, input1);
- changed = true;
- }
+ rewriter.replaceAllUsesWith(input, canonical);
+ changed = true;
}
}
return success(changed);
From 090a091adceae1e26889370cf259ed9deb9bab6f Mon Sep 17 00:00:00 2001
From: Yang Bai <yangb at nvidia.com>
Date: Fri, 13 Mar 2026 03:36:06 -0700
Subject: [PATCH 2/2] [mlir] NFC: Extract getOwnerOfValue helper function
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply at anthropic.com>
---
mlir/lib/Interfaces/ControlFlowInterfaces.cpp | 15 +++++++++------
1 file changed, 9 insertions(+), 6 deletions(-)
diff --git a/mlir/lib/Interfaces/ControlFlowInterfaces.cpp b/mlir/lib/Interfaces/ControlFlowInterfaces.cpp
index 1a0236122cf4d..8464b633a2625 100644
--- a/mlir/lib/Interfaces/ControlFlowInterfaces.cpp
+++ b/mlir/lib/Interfaces/ControlFlowInterfaces.cpp
@@ -931,6 +931,14 @@ struct RemoveDeadRegionBranchOpSuccessorInputs : public RewritePattern {
}
};
+/// Identify what a value belongs to: an op result maps to the Operation that
+/// defines it, and a block argument maps to the Block it is an argument of.
+static void *getOwnerOfValue(Value value) {
+  if (Operation *definingOp = value.getDefiningOp())
+    return definingOp;
+  return cast<BlockArgument>(value).getOwner();
+}
+
/// Get the block argument or op result number of the given value.
static unsigned getArgOrResultNumber(Value value) {
if (auto opResult = llvm::dyn_cast<OpResult>(value))
@@ -1034,12 +1042,7 @@ struct RemoveDuplicateSuccessorInputUses : public RewritePattern {
sig.emplace_back(operand->getOwner(), operand->get());
llvm::sort(sig, sigEntryLess);
- // Determine the owner (block for block args, defining op for results).
- void *owner;
- if (auto arg = dyn_cast<BlockArgument>(input))
- owner = arg.getOwner();
- else
- owner = input.getDefiningOp();
+ void *owner = getOwnerOfValue(input);
auto [it, inserted] = signatureToCanonical.try_emplace(
MapKey{std::move(sig), owner}, input);
More information about the Mlir-commits
mailing list