[Mlir-commits] [mlir] [mlir] Consolidate patterns into `RegionBranchOpInterface` patterns (PR #174094)

Ivan Butygin llvmlistbot at llvm.org
Sat Jan 10 09:42:48 PST 2026


================
@@ -540,3 +559,480 @@ Region *mlir::getEnclosingRepetitiveRegion(Value value) {
   LDBG() << "No enclosing repetitive region found for value";
   return nullptr;
 }
+
+/// Return "true" if `a` can be used in lieu of `b`, where `b` is a region
+/// successor input and `a` is a "reachable value" of `b`. Reachable values
+/// are successor operand values that are (maybe transitively) forwarded to
+/// `b`.
+static bool isDefinedBefore(Operation *regionBranchOp, Value a, Value b) {
+  assert((b.getDefiningOp() == regionBranchOp ||
+          b.getParentRegion()->getParentOp() == regionBranchOp) &&
+         "b must be a region successor input");
+
+  // Case 1: `a` is defined inside of the region branch op. `a` must be
+  // directly nested in the region branch op. Otherwise, it could not have
+  // been among the reachable values for a region successor input.
+  if (a.getParentRegion()->getParentOp() == regionBranchOp) {
+    // Case 1.1: If `b` is a result of the region branch op, `a` is not in
+    // scope for `b`.
+    // Example:
+    // %b = region_op({
+    // ^bb0(%a1: ...):
+    //   %a2 = ...
+    // })
+    if (isa<OpResult>(b))
+      return false;
+
+    // Case 1.2: `b` is an entry block argument of a region. `a` is in scope
+    // for `b` only if it is also an entry block argument of the same region.
+    // Example:
+    // region_op({
+    // ^bb0(%b: ..., %a: ...):
+    //   ...
+    // })
+    assert(isa<BlockArgument>(b) && "b must be a block argument");
+    return isa<BlockArgument>(a) && cast<BlockArgument>(a).getOwner() ==
+                                        cast<BlockArgument>(b).getOwner();
+  }
+
+  // Case 2: `a` is defined outside of the region branch op. In that case, we
+  // can safely assume that `a` was defined before `b`. Otherwise, it could not
+  // be among the reachable values for a region successor input.
+  // Example:
+  // {   <- %a1 parent region begins here.
+  // ^bb0(%a1: ...):
+  //   %a2 = ...
+  //   %b1 = region_op({
+  //   ^bb1(%b2: ...):
+  //     ...
+  //   })
+  // }
+  return true;
+}
+
+/// Compute all non-successor-input values that a successor input could have
+/// based on the given successor input to successor operand mapping.
+///
+/// Example 1:
+/// %r = scf.if ... {
+///   scf.yield %a : ...
+/// } else {
+///   scf.yield %b : ...
+/// }
+/// reachableValues(%r) = {%a, %b}
+///
+/// Example 2:
+/// %r = scf.for ... iter_args(%arg0 = %0) -> ... {
+///   scf.yield %arg0 : ...
+/// }
+/// reachableValues(%arg0) = {%0}
+/// reachableValues(%r) = {%0}
+///
+/// Example 3:
+/// %r = scf.for ... iter_args(%arg0 = %0) -> ... {
+///   ...
+///   scf.yield %1 : ...
+/// }
+/// reachableValues(%arg0) = {%0, %1}
+/// reachableValues(%r) = {%0, %1}
+static llvm::SmallDenseSet<Value> computeReachableValuesFromSuccessorInput(
+    Value value, const RegionBranchInverseSuccessorMapping &inputToOperands) {
+  assert(inputToOperands.contains(value) && "value must be a successor input");
+  // Starting with the given value, trace back all predecessor values (i.e.,
+  // preceding successor operands) and add them to the set of reachable values.
+  // If the successor operand is again a successor input, do not add it to
+  // result set, but instead continue the traversal.
+  llvm::SmallDenseSet<Value> reachableValues;
+  llvm::SmallDenseSet<Value> visited;
+  SmallVector<Value> worklist;
+  worklist.push_back(value);
+  while (!worklist.empty()) {
+    Value next = worklist.pop_back_val();
+    auto it = inputToOperands.find(next);
+    if (it == inputToOperands.end()) {
+      reachableValues.insert(next);
+      continue;
+    }
+    for (OpOperand *operand : it->second)
+      if (visited.insert(operand->get()).second)
+        worklist.push_back(operand->get());
+  }
+  // Note: The result does not contain any successor inputs. (Therefore,
+  // `value` is also guaranteed to be excluded.)
+  return reachableValues;
+}
+
+namespace {
+/// Try to make successor inputs dead by replacing their uses with values that
+/// are not successor inputs. This pattern enables additional canonicalization
+/// opportunities for RemoveDeadRegionBranchOpSuccessorInputs.
+///
+/// Example:
+///
+/// %r0, %r1 = scf.for ... iter_args(%arg0 = %0, %arg1 = %1) -> ... {
+///   scf.yield %arg1, %arg1 : ...
+/// }
+/// use(%r0, %r1)
+///
+/// reachableValues(%r0) = {%0, %1}
+/// reachableValues(%r1) = {%1} ==> replace uses of %r1 with %1.
+/// reachableValues(%arg0) = {%0, %1}
+/// reachableValues(%arg1) = {%1} ==> replace uses of %arg1 with %1.
+///
+/// IR after pattern application:
+///
+/// %r0, %r1 = scf.for ... iter_args(%arg0 = %0, %arg1 = %1) -> ... {
+///   scf.yield %1, %1 : ...
+/// }
+/// use(%r0, %1)
+///
+/// Note that %r1 and %arg1 are dead now. The IR can now be further
+/// canonicalized by RemoveDeadRegionBranchOpSuccessorInputs.
+struct MakeRegionBranchOpSuccessorInputsDead : public RewritePattern {
+  MakeRegionBranchOpSuccessorInputsDead(MLIRContext *context, StringRef name,
+                                        PatternBenefit benefit = 1)
+      : RewritePattern(name, benefit, context) {}
+
+  LogicalResult matchAndRewrite(Operation *op,
+                                PatternRewriter &rewriter) const override {
+    assert(!op->hasTrait<OpTrait::IsIsolatedFromAbove>() &&
+           "isolated-from-above ops are not supported");
+
+    // Compute the mapping of successor inputs to successor operands.
+    auto regionBranchOp = cast<RegionBranchOpInterface>(op);
+    RegionBranchInverseSuccessorMapping inputToOperands;
+    regionBranchOp.getSuccessorInputOperandMapping(inputToOperands);
+
+    // Try to replace the uses of each successor input one-by-one.
+    bool changed = false;
+    for (Value value : inputToOperands.keys()) {
+      // Nothing to do for successor inputs that are already dead.
+      if (value.use_empty())
+        continue;
+      // Nothing to do for successor inputs that may have multiple reachable
+      // values.
+      llvm::SmallDenseSet<Value> reachableValues =
+          computeReachableValuesFromSuccessorInput(value, inputToOperands);
+      if (reachableValues.size() != 1)
+        continue;
+      assert(*reachableValues.begin() != value &&
+             "successor inputs are supposed to be excluded");
+      // Do not replace `value` with the found reachable value if doing so
+      // would violate dominance. Example:
+      // %r = scf.execute_region ... {
+      //   %a = ...
+      //   scf.yield %a : ...
+      // }
+      // use(%r)
+      // In the above example, reachableValues(%r) = {%a}, but %a cannot be
+      // used as a replacement for %r due to dominance / scope.
+      if (!isDefinedBefore(regionBranchOp, *reachableValues.begin(), value))
+        continue;
+      rewriter.replaceAllUsesWith(value, *reachableValues.begin());
+      changed = true;
+    }
+    return success(changed);
+  }
+};
+
+/// Lookup a bit vector in the given mapping (DenseMap). If the key was not
+/// found, create a new bit vector with the given size and initialize it with
+/// false.
+template <typename MappingTy, typename KeyTy>
+static BitVector &lookupOrCreateBitVector(MappingTy &mapping, KeyTy key,
+                                          unsigned size) {
+  return mapping.try_emplace(key, size, false).first->second;
+}
+
+/// Compute tied successor inputs. Tied successor inputs are successor inputs
+/// that come as a set. If you erase one value from a set, you must erase all
+/// values from the set. Otherwise, the op would become structurally invalid.
+/// Each successor input appears in exactly one set.
+///
+/// Example:
+/// %r0, %r1 = scf.for ... iter_args(%arg0 = %0, %arg1 = %1) -> ... {
+///   ...
+/// }
+/// There are two sets: {{%r0, %arg0}, {%r1, %arg1}}.
+static llvm::EquivalenceClasses<Value> computeTiedSuccessorInputs(
+    const RegionBranchSuccessorMapping &operandToInputs) {
+  llvm::EquivalenceClasses<Value> tiedSuccessorInputs;
+  for (const auto &[operand, inputs] : operandToInputs) {
+    assert(!inputs.empty() && "expected non-empty inputs");
+    Value firstInput = inputs.front();
+    tiedSuccessorInputs.insert(firstInput);
+    for (Value nextInput : llvm::drop_begin(inputs)) {
+      // As we explore more successor operand to successor input mappings,
+      // existing sets may get merged.
+      tiedSuccessorInputs.unionSets(firstInput, nextInput);
+    }
+  }
+  return tiedSuccessorInputs;
+}
+
+/// Remove dead successor inputs from region branch ops. A successor input is
+/// dead if it has no uses. Successor inputs come in sets of tied values: if
+/// you remove one value from a set, you must remove all values from the set.
+/// Furthermore, successor operands must also be removed. (Op operands are not
+/// part of the set, but the set is built based on the successor operand to
+/// successor input mapping.)
+///
+/// Example 1:
+/// %r0, %r1 = scf.for ... iter_args(%arg0 = %0, %arg1 = %1) -> ... {
+///   scf.yield %0, %arg1 : ...
+/// }
+/// use(%0, %1)
+///
+/// There are two sets: {{%r0, %arg0}, {%r1, %arg1}}. All values in the first
+/// set are dead, so %arg0 and %r0 can be removed, but not %r1 and %arg1. The
+/// resulting IR is as follows:
+///
+/// %r1 = scf.for ... iter_args(%arg1 = %1) -> ... {
+///   scf.yield %arg1 : ...
+/// }
+/// use(%0, %1)
+///
+/// Example 2:
+/// %r0, %r1 = scf.while (%arg0 = %0) {
+///   scf.condition(...) %arg0, %arg0 : ...
+/// } do {
+/// ^bb0(%arg1: ..., %arg2: ...):
+///   scf.yield %arg1 : ...
+/// }
+/// There are three sets: {{%r0, %arg1}, {%r1, %arg2}, {%arg0}}.
+///
+/// Example 3:
+/// %r1, %r2 = scf.if ... {
+///   scf.yield %0, %1 : ...
+/// } else {
+///   scf.yield %2, %3 : ...
+/// }
+/// There are two sets: {{%r1}, {%r2}}. Each set has one value, so each value
+/// can be removed independently of the other values.
+struct RemoveDeadRegionBranchOpSuccessorInputs : public RewritePattern {
+  RemoveDeadRegionBranchOpSuccessorInputs(MLIRContext *context, StringRef name,
+                                          PatternBenefit benefit = 1)
+      : RewritePattern(name, benefit, context) {}
+
+  LogicalResult matchAndRewrite(Operation *op,
+                                PatternRewriter &rewriter) const override {
+    assert(!op->hasTrait<OpTrait::IsIsolatedFromAbove>() &&
+           "isolated-from-above ops are not supported");
+
+    // Compute tied values: values that must come as a set. If you remove one,
+    // you must remove all. If a successor op operand is forwarded to two
+    // successor inputs %a and %b, both %a and %b are in the same set.
+    auto regionBranchOp = cast<RegionBranchOpInterface>(op);
+    RegionBranchSuccessorMapping operandToInputs;
+    regionBranchOp.getSuccessorOperandInputMapping(operandToInputs);
+    llvm::EquivalenceClasses<Value> tiedSuccessorInputs =
+        computeTiedSuccessorInputs(operandToInputs);
+
+    // Determine which values to remove and group them by block and operation.
+    SmallVector<Value> valuesToRemove;
+    DenseMap<Block *, BitVector> blockArgsToRemove;
+    BitVector resultsToRemove(regionBranchOp->getNumResults(), false);
+    // Iterate over all sets of tied successor inputs.
+    for (auto it = tiedSuccessorInputs.begin(), e = tiedSuccessorInputs.end();
+         it != e; ++it) {
+      if (!(*it)->isLeader())
+        continue;
+
+      // Value can be removed if it is dead and all other tied values are also
+      // dead.
+      bool allDead = true;
+      for (auto memberIt = tiedSuccessorInputs.member_begin(**it);
+           memberIt != tiedSuccessorInputs.member_end(); ++memberIt) {
+        // Iterate over all values in the set and check their liveness.
+        if (!memberIt->use_empty()) {
+          allDead = false;
+          break;
+        }
+      }
+      if (!allDead)
+        continue;
+
+      // The entire set is dead. Group values by block and operation to
+      // simplify removal.
+      for (auto memberIt = tiedSuccessorInputs.member_begin(**it);
+           memberIt != tiedSuccessorInputs.member_end(); ++memberIt) {
+        if (auto arg = dyn_cast<BlockArgument>(*memberIt)) {
+          // Set blockArgsToRemove[block][arg_number] = true.
+          BitVector &vector =
+              lookupOrCreateBitVector(blockArgsToRemove, arg.getOwner(),
+                                      arg.getOwner()->getNumArguments());
+          vector.set(arg.getArgNumber());
+        } else {
+          // Set resultsToRemove[result_number] = true.
+          OpResult result = cast<OpResult>(*memberIt);
+          assert(result.getDefiningOp() == regionBranchOp &&
+                 "result must be a region branch op result");
+          resultsToRemove.set(result.getResultNumber());
+        }
+        valuesToRemove.push_back(*memberIt);
+      }
+    }
+
+    if (valuesToRemove.empty())
+      return rewriter.notifyMatchFailure(op, "no values to remove");
+
+    // Find operands that must be removed together with the values.
+    RegionBranchInverseSuccessorMapping inputsToOperands =
+        invertRegionBranchSuccessorMapping(operandToInputs);
+    DenseMap<Operation *, llvm::BitVector> operandsToRemove;
+    for (Value value : valuesToRemove) {
+      for (OpOperand *operand : inputsToOperands[value]) {
+        // Set operandsToRemove[op][operand_number] = true.
+        BitVector &vector =
+            lookupOrCreateBitVector(operandsToRemove, operand->getOwner(),
+                                    operand->getOwner()->getNumOperands());
+        vector.set(operand->getOperandNumber());
+      }
+    }
+
+    // Erase operands.
+    for (auto &pair : operandsToRemove) {
+      Operation *op = pair.first;
+      BitVector &operands = pair.second;
+      rewriter.modifyOpInPlace(op, [&]() { op->eraseOperands(operands); });
+    }
+
+    // Erase block arguments.
+    for (auto &pair : blockArgsToRemove) {
+      Block *block = pair.first;
+      BitVector &blockArg = pair.second;
+      rewriter.modifyOpInPlace(block->getParentOp(),
+                               [&]() { block->eraseArguments(blockArg); });
+    }
+
+    // Erase op results.
+    if (resultsToRemove.any())
+      rewriter.eraseOpResults(regionBranchOp, resultsToRemove);
+
+    return success();
+  }
+};
+
+/// Return "true" if the two values are owned by the same operation or block.
+static bool haveSameOwner(Value a, Value b) {
+  void *aOwner, *bOwner;
+  if (auto arg = dyn_cast<BlockArgument>(a))
+    aOwner = arg.getOwner();
+  else
+    aOwner = a.getDefiningOp();
+  if (auto arg = dyn_cast<BlockArgument>(b))
+    bOwner = arg.getOwner();
+  else
+    bOwner = b.getDefiningOp();
+  return aOwner == bOwner;
+}
+
+/// Get the block argument or op result number of the given value.
+static unsigned getArgOrResultNumber(Value value) {
+  if (auto opResult = llvm::dyn_cast<OpResult>(value))
+    return opResult.getResultNumber();
+  return llvm::cast<BlockArgument>(value).getArgNumber();
+}
+
+/// Find duplicate successor inputs and make all dead except for one. Two
+/// successor inputs are "duplicate" if their corresponding successor operands
+/// have the same values. This pattern enables additional canonicalization
+/// opportunities for RemoveDeadRegionBranchOpSuccessorInputs.
+///
+/// Example:
+/// %r0, %r1 = scf.for ... iter_args(%arg0 = %0, %arg1 = %0) -> ... {
+///   use(%arg0, %arg1)
+///   ...
+///   scf.yield %x, %x : ...
+/// }
+/// use(%r0, %r1)
+///
+/// Operands of successor input %r0: [%0, %x]
+/// Operands of successor input %r1: [%0, %x] ==> DUPLICATE!
+/// Replace %r1 with %r0.
+///
+/// Operands of successor input %arg0: [%0, %x]
+/// Operands of successor input %arg1: [%0, %x] ==> DUPLICATE!
+/// Replace %arg1 with %arg0. (We have to make sure that we make the same
+/// decision as for the other tied successor inputs above. Otherwise, a set of
+/// tied successor inputs may not become entirely dead.)
+///
+/// The resulting IR is as follows:
+/// %r0, %r1 = scf.for ... iter_args(%arg0 = %0, %arg1 = %0) -> ... {
+///   use(%arg0, %arg0)
+///   ...
+///   scf.yield %x, %x : ...
+/// }
+/// use(%r0, %r0)  // Note: We don't want use(%r1, %r1), which is also correct,
+///                // but does not help with further canonicalizations.
+struct RemoveDuplicateSuccessorInputUses : public RewritePattern {
+  RemoveDuplicateSuccessorInputUses(MLIRContext *context, StringRef name,
+                                    PatternBenefit benefit = 1)
+      : RewritePattern(name, benefit, context) {}
+
+  LogicalResult matchAndRewrite(Operation *op,
+                                PatternRewriter &rewriter) const override {
+    assert(!op->hasTrait<OpTrait::IsIsolatedFromAbove>() &&
+           "isolated-from-above ops are not supported");
+
+    // Collect all successor inputs and sort them. When dropping the uses of a
+    // successor input, we'd like to also drop the uses of the same tied
+    // successor inputs. Otherwise, a set of tied successor inputs may not
+    // become entirely dead, which is required for
+    // RemoveDeadRegionBranchOpSuccessorInputs to be able to erase them.
+    // (Sorting is not required for correctness.)
+    auto regionBranchOp = cast<RegionBranchOpInterface>(op);
+    RegionBranchInverseSuccessorMapping inputsToOperands;
+    regionBranchOp.getSuccessorInputOperandMapping(inputsToOperands);
+    SmallVector<Value> inputs = llvm::to_vector(inputsToOperands.keys());
+    llvm::sort(inputs, [](Value a, Value b) {
+      return getArgOrResultNumber(a) < getArgOrResultNumber(b);
+    });
+
+    // Check every distinct pair of successor inputs for duplicates. Replace
+    // `input2` with `input1` if they are duplicates.
+    bool changed = false;
+    unsigned numInputs = inputs.size();
+    for (auto i : llvm::seq<unsigned>(0, numInputs)) {
+      Value input1 = inputs[i];
+      for (auto j : llvm::seq<unsigned>(i + 1, numInputs)) {
+        Value input2 = inputs[j];
+        // Nothing to do if input2 is already dead.
+        if (input2.use_empty())
+          continue;
+        // Replace only values that belong to the same block / operation.
+        // This implies that the two values are either both block arguments or
+        // both op results.
+        if (!haveSameOwner(input1, input2))
+          continue;
+
+        // Gather the predecessor value for each predecessor (region branch
+        // point). The two inputs are duplicates if each predecessor forwards
+        // the same value.
+        DenseMap<Operation *, Value> operands1, operands2;
----------------
Hardcode84 wrote:

`SmallDenseMap`s

https://github.com/llvm/llvm-project/pull/174094


More information about the Mlir-commits mailing list