[Mlir-commits] [mlir] [mlir][ControlFlow] Improve time complexity of RegionBranchOpInterface canonicalization patterns (PR #186114)

Yang Bai llvmlistbot at llvm.org
Thu Mar 12 06:00:06 PDT 2026


https://github.com/yangtetris created https://github.com/llvm/llvm-project/pull/186114

Optimize two canonicalization patterns in `ControlFlowInterfaces.cpp`:

1. **RemoveDuplicateSuccessorInputUses**: Replace the O(n² * k) pairwise comparison of successor inputs with O(n * k * max(log k, log n)) grouping. Each input gets a signature (its sorted list of (predecessor, forwarded value) pairs), and inputs with the same signature and owner are grouped via a `std::map`. Here _n_ is the number of successor inputs and _k_ is the number of predecessors per input.

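2. **MakeRegionBranchOpSuccessorInputsDead**: Add an optional early exit to `computeReachableValuesFromSuccessorInput`. This pattern only needs to know whether a successor input has exactly one reachable value, so the traversal now stops as soon as a second reachable value is found instead of exploring the whole predecessor graph.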

>From 48f676182713fbf4e47733ef76aad4dfc8495f73 Mon Sep 17 00:00:00 2001
From: Yang Bai <yangb at nvidia.com>
Date: Thu, 12 Mar 2026 03:51:43 -0700
Subject: [PATCH] [mlir] Improve time complexity of RegionBranchOpInterface
 canonicalization patterns
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Optimize two canonicalization patterns in ControlFlowInterfaces.cpp:

1. RemoveDuplicateSuccessorInputUses: Replace the O(n² * k) pairwise
   comparison of successor inputs with O(n * k * max(log k, log n))
   grouping by signature (the sorted (predecessor, forwarded value)
   pairs of each input) using std::map, where n is the number of
   successor inputs and k is the number of predecessors per input.

2. MakeRegionBranchOpSuccessorInputsDead: Add an optional early exit
   to computeReachableValuesFromSuccessorInput. This pattern only needs
   to know whether there is exactly one reachable value, so the
   traversal now stops as soon as a second one is found.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply at anthropic.com>
---
 mlir/lib/Interfaces/ControlFlowInterfaces.cpp | 124 +++++++++++-------
 1 file changed, 73 insertions(+), 51 deletions(-)

diff --git a/mlir/lib/Interfaces/ControlFlowInterfaces.cpp b/mlir/lib/Interfaces/ControlFlowInterfaces.cpp
index 2f95531455b2b..c696cd6785f1d 100644
--- a/mlir/lib/Interfaces/ControlFlowInterfaces.cpp
+++ b/mlir/lib/Interfaces/ControlFlowInterfaces.cpp
@@ -6,6 +6,7 @@
 //
 //===----------------------------------------------------------------------===//
 
+#include <map>
 #include <utility>
 
 #include "mlir/IR/BuiltinTypes.h"
@@ -630,6 +631,15 @@ static bool isDefinedBefore(Operation *regionBranchOp, Value a, Value b) {
 /// Compute all non-successor-input values that a successor input could have
 /// based on the given successor input to successor operand mapping.
 ///
+/// Starting with the given value, trace back all predecessor values (i.e.,
+/// preceding successor operands) and add them to the set of reachable values.
+/// If the successor operand is again a successor input, do not add it to the
+/// result set, but instead continue the traversal.
+///
+/// If `maxReachableValues` is non-zero, the traversal stops as soon as more
+/// than that many values have been found; the returned set is then partial,
+/// so only "size() > maxReachableValues" is meaningful (e.g. "exactly one?").
+///
 /// Example 1:
 /// %r = scf.if ... {
 ///   scf.yield %a : ...
@@ -653,12 +663,9 @@ static bool isDefinedBefore(Operation *regionBranchOp, Value a, Value b) {
 /// reachableValues(%arg0) = {%0, %1}
 /// reachableValues(%r) = {%0, %1}
 static llvm::SmallDenseSet<Value> computeReachableValuesFromSuccessorInput(
-    Value value, const RegionBranchInverseSuccessorMapping &inputToOperands) {
+    Value value, const RegionBranchInverseSuccessorMapping &inputToOperands,
+    unsigned maxReachableValues = 0) {
   assert(inputToOperands.contains(value) && "value must be a successor input");
-  // Starting with the given value, trace back all predecessor values (i.e.,
-  // preceding successor operands) and add them to the set of reachable values.
-  // If the successor operand is again a successor input, do not add it to
-  // result set, but instead continue the traversal.
   llvm::SmallDenseSet<Value> reachableValues;
   llvm::SmallDenseSet<Value> visited;
   SmallVector<Value> worklist;
@@ -668,6 +675,10 @@ static llvm::SmallDenseSet<Value> computeReachableValuesFromSuccessorInput(
     auto it = inputToOperands.find(next);
     if (it == inputToOperands.end()) {
       reachableValues.insert(next);
+      // Early exit: stop traversal if more reachable values than the caller
+      // cares about have been found.
+      if (maxReachableValues > 0 && reachableValues.size() > maxReachableValues)
+        return reachableValues;
       continue;
     }
     for (OpOperand *operand : it->second)
@@ -729,7 +740,8 @@ struct MakeRegionBranchOpSuccessorInputsDead : public RewritePattern {
       // Nothing to do for successor inputs that may have multiple reachable
       // values.
       llvm::SmallDenseSet<Value> reachableValues =
-          computeReachableValuesFromSuccessorInput(value, inputToOperands);
+          computeReachableValuesFromSuccessorInput(value, inputToOperands,
+                                                   /*maxReachableValues=*/1);
       if (reachableValues.size() != 1)
         continue;
       assert(*reachableValues.begin() != value &&
@@ -930,20 +942,6 @@ struct RemoveDeadRegionBranchOpSuccessorInputs : public RewritePattern {
   }
 };
 
-/// Return "true" if the two values are owned by the same operation or block.
-static bool haveSameOwner(Value a, Value b) {
-  void *aOwner, *bOwner;
-  if (auto arg = dyn_cast<BlockArgument>(a))
-    aOwner = arg.getOwner();
-  else
-    aOwner = a.getDefiningOp();
-  if (auto arg = dyn_cast<BlockArgument>(b))
-    bOwner = arg.getOwner();
-  else
-    bOwner = b.getDefiningOp();
-  return aOwner == bOwner;
-}
-
 /// Get the block argument or op result number of the given value.
 static unsigned getArgOrResultNumber(Value value) {
   if (auto opResult = llvm::dyn_cast<OpResult>(value))
@@ -1006,39 +1004,63 @@ struct RemoveDuplicateSuccessorInputUses : public RewritePattern {
       return getArgOrResultNumber(a) < getArgOrResultNumber(b);
     });
 
-    // Check every distinct pair of successor inputs for duplicates. Replace
-    // `input2` with `input1` if they are duplicates.
+    // Group inputs by their operand "signature" to find duplicates. Two
+    // successor inputs are duplicates if each predecessor (region branch point)
+    // forwards the same value for both. Let n = number of successor inputs and
+    // k = number of predecessors per input. Instead of comparing every pair of
+    // inputs (O(n² * k)), we build a signature for each input and group them
+    // via a std::map.
+    //
+    // A signature is a sorted list of (predecessor, forwarded value) pairs.
+    // Within each group, every input after the first one (in the sorted order
+    // above) has its uses replaced with that first, canonical input.
+    using SigEntry = std::pair<Operation *, Value>;
+    using Signature = SmallVector<SigEntry>;
+    auto sigEntryLess = [](const SigEntry &a, const SigEntry &b) {
+      if (a.first != b.first)
+        return a.first < b.first;
+      return a.second.getAsOpaquePointer() < b.second.getAsOpaquePointer();
+    };
+    // The map key is (signature, owner). Two inputs are duplicates only if they
+    // have the same signature AND the same owner (block or defining op); this
+    // keeps block arguments and op results from ever being merged.
+    using MapKey = std::pair<Signature, void *>;
+    auto mapKeyLess = [&](const MapKey &a, const MapKey &b) {
+      if (a.second != b.second)
+        return a.second < b.second;
+      return std::lexicographical_compare(a.first.begin(), a.first.end(),
+                                          b.first.begin(), b.first.end(),
+                                          sigEntryLess);
+    };
+    std::map<MapKey, Value, decltype(mapKeyLess)> signatureToCanonical(
+        mapKeyLess);
     bool changed = false;
-    unsigned numInputs = inputs.size();
-    for (auto i : llvm::seq<unsigned>(0, numInputs)) {
-      Value input1 = inputs[i];
-      for (auto j : llvm::seq<unsigned>(i + 1, numInputs)) {
-        Value input2 = inputs[j];
-        // Nothing to do if input2 is already dead.
-        if (input2.use_empty())
-          continue;
-        // Replace only values that belong to the same block / operation.
-        // This implies that the two values are either both block arguments or
-        // both op results.
-        if (!haveSameOwner(input1, input2))
+    // Total complexity: O(n * k * max(log k, log n)). Per input, sorting the
+    // signature is O(k log k); the map lookup does O(log n) O(k) comparisons.
+    for (Value input : inputs) {
+      // Gather the predecessor value for each predecessor (region branch
+      // point) and sort them to form this input's signature.
+      Signature sig;
+      for (OpOperand *operand : inputsToOperands[input])
+        sig.emplace_back(operand->getOwner(), operand->get());
+      llvm::sort(sig, sigEntryLess);
+
+      // Determine the owner (block for block args, defining op for results).
+      void *owner;
+      if (auto arg = dyn_cast<BlockArgument>(input))
+        owner = arg.getOwner();
+      else
+        owner = input.getDefiningOp();
+
+      auto [it, inserted] = signatureToCanonical.try_emplace(
+          MapKey{std::move(sig), owner}, input);
+      if (!inserted) {
+        Value canonical = it->second;
+        // Nothing to do if input is already dead.
+        if (input.use_empty())
           continue;
-
-        // Gather the predecessor value for each predecessor (region branch
-        // point). The two inputs are duplicates if each predecessor forwards
-        // the same value.
-        llvm::SmallDenseMap<Operation *, Value> operands1, operands2;
-        for (OpOperand *operand : inputsToOperands[input1]) {
-          assert(!operands1.contains(operand->getOwner()));
-          operands1[operand->getOwner()] = operand->get();
-        }
-        for (OpOperand *operand : inputsToOperands[input2]) {
-          assert(!operands2.contains(operand->getOwner()));
-          operands2[operand->getOwner()] = operand->get();
-        }
-        if (operands1 == operands2) {
-          rewriter.replaceAllUsesWith(input2, input1);
-          changed = true;
-        }
+        rewriter.replaceAllUsesWith(input, canonical);
+        changed = true;
       }
     }
     return success(changed);



More information about the Mlir-commits mailing list