[Mlir-commits] [mlir] 0b24580 - [mlir][Interfaces][NFC] Add `RegionBranchOpInterface` helper for forwarded values (#173981)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Jan 1 02:56:09 PST 2026


Author: Matthias Springer
Date: 2026-01-01T11:56:05+01:00
New Revision: 0b24580a26644c8a44abc6fba97609552ba41d42

URL: https://github.com/llvm/llvm-project/commit/0b24580a26644c8a44abc6fba97609552ba41d42
DIFF: https://github.com/llvm/llvm-project/commit/0b24580a26644c8a44abc6fba97609552ba41d42.diff

LOG: [mlir][Interfaces][NFC] Add `RegionBranchOpInterface` helper for forwarded values (#173981)

Add a helper function to compute a mapping of successor operands to
successor inputs. This mapping was previously computed separately in
various places. Also add a helper function to gather all region branch points.

This commit is in preparation for a bug fix / partial redesign of
`-remove-dead-values`. It also removes some duplicate code in
various places.

Added: 
    

Modified: 
    mlir/include/mlir/Interfaces/ControlFlowInterfaces.h
    mlir/include/mlir/Interfaces/ControlFlowInterfaces.td
    mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp
    mlir/lib/Dialect/Bufferization/Transforms/BufferViewFlowAnalysis.cpp
    mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
    mlir/lib/Interfaces/ControlFlowInterfaces.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Interfaces/ControlFlowInterfaces.h b/mlir/include/mlir/Interfaces/ControlFlowInterfaces.h
index bfc24c18429ed..566f4b8fadb5d 100644
--- a/mlir/include/mlir/Interfaces/ControlFlowInterfaces.h
+++ b/mlir/include/mlir/Interfaces/ControlFlowInterfaces.h
@@ -176,6 +176,19 @@ namespace detail {
 LogicalResult verifyTypesAlongControlFlowEdges(Operation *op);
 } //  namespace detail
 
+/// A mapping from successor operands to successor inputs.
+///
+/// * A successor operand is an operand of a region branch op or region
+///   branch terminator that is forwarded to a successor input.
+/// * A successor input is a block argument of a region or a result of the
+///   region branch op that is populated by a successor operand.
+///
+/// The mapping is 1:N. Each successor operand may be forwarded to multiple
+/// successor inputs, because control flow can dispatch to multiple possible
+/// successors. Operands that are not forwarded at all are not present in
+/// the mapping.
+using RegionBranchSuccessorMapping = DenseMap<OpOperand *, SmallVector<Value>>;
+
 /// This class represents a successor of a region. A region successor can either
 /// be another region, or the parent operation. If the successor is a region,
 /// this class represents the destination region, as well as a set of arguments

diff  --git a/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td b/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td
index 8760c8b8715f9..2e654ba04ffe5 100644
--- a/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td
+++ b/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td
@@ -251,30 +251,16 @@ def RegionBranchOpInterface : OpInterface<"RegionBranchOpInterface"> {
            "::llvm::SmallVectorImpl<::mlir::RegionBranchPoint> &":$predecessors),
       [{}],
       /*defaultImplementation=*/[{
-        ::llvm::SmallVector<::mlir::RegionSuccessor> successors;
-        $_op.getSuccessorRegions(RegionBranchPoint::parent(),
-                                 successors);
-        if (llvm::any_of(successors, [&] (const RegionSuccessor & succ) {
+        auto op = cast<RegionBranchOpInterface>($_op.getOperation());
+        for (::mlir::RegionBranchPoint point : op.getAllRegionBranchPoints()) {
+          ::llvm::SmallVector<::mlir::RegionSuccessor> successors;
+          op.getSuccessorRegions(point, successors);
+          bool isPred = llvm::any_of(successors, [&] (const auto &succ) {
             return succ.getSuccessor() == successor.getSuccessor() ||
-              (succ.isParent() && successor.isParent());
-          }))
-          predecessors.push_back(RegionBranchPoint::parent());
-        for (Region &region : $_op->getRegions()) {
-          for (::mlir::Block &block : region) {
-            if (block.empty())
-              continue;
-            if (auto terminator =
-                    dyn_cast<RegionBranchTerminatorOpInterface>(block.back())) {
-              ::llvm::SmallVector<::mlir::RegionSuccessor> successors;
-              $_op.getSuccessorRegions(RegionBranchPoint(terminator),
-                                       successors);
-              if (llvm::any_of(successors, [&] (const RegionSuccessor & succ) {
-                  return succ.getSuccessor() == successor.getSuccessor() ||
-                    (succ.isParent() && successor.isParent());
-                }))
-                predecessors.push_back(terminator);
-            }
-          }
+                   (succ.isParent() && successor.isParent());
+            });
+          if (isPred)
+            predecessors.push_back(point);
         }
       }]>,
     InterfaceMethod<[{
@@ -359,6 +345,19 @@ def RegionBranchOpInterface : OpInterface<"RegionBranchOpInterface"> {
     /// of the respective terminator.
     ::mlir::OperandRange getSuccessorOperands(
         ::mlir::RegionBranchPoint src, ::mlir::RegionSuccessor dest);
+
+    /// Build a mapping from successor operands to successor inputs. Each
+    /// successor operand may be forwarded to multiple successor inputs.
+    /// Operands that are not forwarded are not added. Inputs are appended to
+    /// existing entries of `mapping`. If `src` is not specified, all branch
+    /// points returned by `getAllRegionBranchPoints()` are considered.
+    void getSuccessorOperandInputMapping(
+        ::mlir::RegionBranchSuccessorMapping &mapping,
+        std::optional<::mlir::RegionBranchPoint> src = std::nullopt);
+
+    /// Return all possible region branch points: the region branch op itself
+    /// (first), followed by the region branch terminators in its regions.
+    ::llvm::SmallVector<::mlir::RegionBranchPoint> getAllRegionBranchPoints();
   }];
 }
 

diff  --git a/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp b/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp
index 8e63ae86753b4..64adb8cf00ba7 100644
--- a/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp
+++ b/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp
@@ -591,29 +591,24 @@ LogicalResult AbstractSparseBackwardDataFlowAnalysis::visitCallableOperation(
 void AbstractSparseBackwardDataFlowAnalysis::visitRegionSuccessors(
     RegionBranchOpInterface branch,
     ArrayRef<AbstractSparseLattice *> operandLattices) {
-  Operation *op = branch.getOperation();
-  SmallVector<RegionSuccessor> successors;
-  SmallVector<Attribute> operands(op->getNumOperands(), nullptr);
-  branch.getEntrySuccessorRegions(operands, successors);
-
-  // All operands not forwarded to any successor. This set can be non-contiguous
-  // in the presence of multiple successors.
-  BitVector unaccounted(op->getNumOperands(), true);
-
-  for (RegionSuccessor &successor : successors) {
-    OperandRange operands = branch.getEntrySuccessorOperands(successor);
-    MutableArrayRef<OpOperand> opoperands = operandsToOpOperands(operands);
-    ValueRange inputs = successor.getSuccessorInputs();
-    for (auto [operand, input] : llvm::zip(opoperands, inputs)) {
-      meet(getLatticeElement(operand.get()),
-           *getLatticeElementFor(getProgramPointAfter(op), input));
-      unaccounted.reset(operand.getOperandNumber());
+  // Operands not forwarded to any successor. This set can be non-contiguous
+  // in the presence of multiple successors.
+  BitVector unaccounted(branch->getNumOperands(), true);
+
+  RegionBranchSuccessorMapping mapping;
+  branch.getSuccessorOperandInputMapping(mapping, RegionBranchPoint::parent());
+  for (const auto &[operand, inputs] : mapping) {
+    for (Value input : inputs) {
+      meet(getLatticeElement(operand->get()),
+           *getLatticeElementFor(getProgramPointAfter(branch), input));
+      unaccounted.reset(operand->getOperandNumber());
     }
   }
+
   // All operands not forwarded to regions are typically parameters of the
   // branch operation itself (for example the boolean for if/else).
   for (int index : unaccounted.set_bits()) {
-    visitBranchOperand(op->getOpOperand(index));
+    visitBranchOperand(branch->getOpOperand(index));
   }
 }
 
@@ -626,24 +621,21 @@ void AbstractSparseBackwardDataFlowAnalysis::
   assert(terminator->getParentOp() == branch.getOperation() &&
          "expected `branch` to be the parent op of `terminator`");
 
-  SmallVector<Attribute> operandAttributes(terminator->getNumOperands(),
-                                           nullptr);
-  SmallVector<RegionSuccessor> successors;
-  terminator.getSuccessorRegions(operandAttributes, successors);
-  // All operands not forwarded to any successor. This set can be
+  // Not all operands are forwarded to a successor. This set can be
   // non-contiguous in the presence of multiple successors.
   BitVector unaccounted(terminator->getNumOperands(), true);
 
-  for (const RegionSuccessor &successor : successors) {
-    ValueRange inputs = successor.getSuccessorInputs();
-    OperandRange operands = terminator.getSuccessorOperands(successor);
-    MutableArrayRef<OpOperand> opOperands = operandsToOpOperands(operands);
-    for (auto [opOperand, input] : llvm::zip(opOperands, inputs)) {
-      meet(getLatticeElement(opOperand.get()),
+  RegionBranchSuccessorMapping mapping;
+  branch.getSuccessorOperandInputMapping(mapping,
+                                         RegionBranchPoint(terminator));
+  for (const auto &[operand, inputs] : mapping) {
+    for (Value input : inputs) {
+      meet(getLatticeElement(operand->get()),
            *getLatticeElementFor(getProgramPointAfter(terminator), input));
-      unaccounted.reset(const_cast<OpOperand &>(opOperand).getOperandNumber());
+      unaccounted.reset(operand->getOperandNumber());
     }
   }
+
   // Visit operands of the branch op not forwarded to the next region.
   // (Like e.g. the boolean of `scf.conditional`)
   for (int index : unaccounted.set_bits()) {

diff  --git a/mlir/lib/Dialect/Bufferization/Transforms/BufferViewFlowAnalysis.cpp b/mlir/lib/Dialect/Bufferization/Transforms/BufferViewFlowAnalysis.cpp
index 0b2e080e52b75..ac94ef9d866fa 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/BufferViewFlowAnalysis.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/BufferViewFlowAnalysis.cpp
@@ -142,37 +142,12 @@ void BufferViewFlowAnalysis::build(Operation *op) {
     }
 
     if (auto regionInterface = dyn_cast<RegionBranchOpInterface>(op)) {
-      // Query the RegionBranchOpInterface to find potential successor regions.
-      // Extract all entry regions and wire all initial entry successor inputs.
-      SmallVector<RegionSuccessor, 2> entrySuccessors;
-      regionInterface.getSuccessorRegions(/*point=*/RegionBranchPoint::parent(),
-                                          entrySuccessors);
-      for (RegionSuccessor &entrySuccessor : entrySuccessors) {
-        // Wire the entry region's successor arguments with the initial
-        // successor inputs.
-        registerDependencies(
-            regionInterface.getEntrySuccessorOperands(entrySuccessor),
-            entrySuccessor.getSuccessorInputs());
-      }
-
-      // Wire flow between regions and from region exits.
-      for (Region &region : regionInterface->getRegions()) {
-        // Iterate over all successor region entries that are reachable from the
-        // current region.
-        SmallVector<RegionSuccessor, 2> successorRegions;
-        regionInterface.getSuccessorRegions(region, successorRegions);
-        for (RegionSuccessor &successorRegion : successorRegions) {
-          // Iterate over all immediate terminator operations and wire the
-          // successor inputs with the successor operands of each terminator.
-          for (Block &block : region)
-            if (auto terminator = dyn_cast<RegionBranchTerminatorOpInterface>(
-                    block.getTerminator()))
-              registerDependencies(
-                  terminator.getSuccessorOperands(successorRegion),
-                  successorRegion.getSuccessorInputs());
-        }
-      }
-
+      // Wire each successor operand with every successor input it populates.
+      RegionBranchSuccessorMapping mapping;
+      regionInterface.getSuccessorOperandInputMapping(mapping);
+      for (const auto &[operand, inputs] : mapping)
+        for (Value input : inputs)
+          registerDependencies({operand->get()}, {input});
       return WalkResult::advance();
     }
 

diff  --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
index 7fc75e7294ea3..60fb13c7c2cd7 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
@@ -1248,20 +1248,15 @@ updateControlFlowOps(mlir::OpBuilder &builder,
                      mlir::RegionBranchTerminatorOpInterface terminator,
                      GetLayoutFnTy getLayoutOfValue) {
   // Only process if the terminator is inside a region branch op.
-  if (!mlir::isa<mlir::RegionBranchOpInterface>(terminator->getParentOp()))
+  auto branchOp = dyn_cast<RegionBranchOpInterface>(terminator->getParentOp());
+  if (!branchOp)
     return success();
 
-  llvm::SmallVector<mlir::RegionSuccessor> successors;
-  llvm::SmallVector<mlir::Attribute> operands(terminator->getNumOperands(),
-                                              nullptr);
-  terminator.getSuccessorRegions(operands, successors);
-
-  for (mlir::RegionSuccessor &successor : successors) {
-    mlir::OperandRange successorOperands =
-        terminator.getSuccessorOperands(successor);
-    mlir::ValueRange successorInputs = successor.getSuccessorInputs();
-    for (auto [successorOperand, successorInput] :
-         llvm::zip(successorOperands, successorInputs)) {
+  RegionBranchSuccessorMapping mapping;
+  branchOp.getSuccessorOperandInputMapping(mapping,
+                                           RegionBranchPoint(terminator));
+  for (const auto &[successorOperand, successorInputs] : mapping) {
+    for (Value successorInput : successorInputs) {
       Type inputType = successorInput.getType();
       // We only need to operate on tensor descriptor or vector types.
       if (!isa<xegpu::TensorDescType, VectorType>(inputType))
@@ -1269,13 +1264,13 @@ updateControlFlowOps(mlir::OpBuilder &builder,
       xegpu::DistributeLayoutAttr successorInputLayout =
           getLayoutOfValue(successorInput);
       xegpu::DistributeLayoutAttr successorOperandLayout =
-          getLayoutOfValue(successorOperand);
+          getLayoutOfValue(successorOperand->get());
 
       // If either of the layouts is not assigned, we cannot proceed.
       if (!successorOperandLayout) {
         LLVM_DEBUG(DBGS() << "No layout assigned for forwarded operand in "
                              "branch terminator: "
-                          << successorOperand << "\n");
+                          << successorOperand->get() << "\n");
         return failure();
       }
       // We expect the layouts to match.

diff  --git a/mlir/lib/Interfaces/ControlFlowInterfaces.cpp b/mlir/lib/Interfaces/ControlFlowInterfaces.cpp
index 51861d7751450..d393ddb8d8336 100644
--- a/mlir/lib/Interfaces/ControlFlowInterfaces.cpp
+++ b/mlir/lib/Interfaces/ControlFlowInterfaces.cpp
@@ -489,6 +489,62 @@ RegionBranchOpInterface::getSuccessorOperands(RegionBranchPoint src,
   return terminator.getSuccessorOperands(dest);
 }
 
+/// Return the OpOperands that back the given operand range, so that callers
+/// can refer to individual `OpOperand`s (e.g., to take their address).
+static MutableArrayRef<OpOperand> operandsToOpOperands(OperandRange operands) {
+  return MutableArrayRef<OpOperand>(operands.getBase(), operands.size());
+}
+
+/// For every operand that `src` forwards to one of its successors, append
+/// the successor input populated by that operand to the operand's entry in
+/// `mapping`. Entries already present in `mapping` are extended, not replaced.
+static void
+getSuccessorOperandInputMapping(RegionBranchOpInterface branchOp,
+                                RegionBranchSuccessorMapping &mapping,
+                                RegionBranchPoint src) {
+  SmallVector<RegionSuccessor> successors;
+  branchOp.getSuccessorRegions(src, successors);
+  for (const RegionSuccessor &dst : successors) {
+    OperandRange operands = branchOp.getSuccessorOperands(src, dst);
+    assert(operands.size() == dst.getSuccessorInputs().size() &&
+           "expected the same number of operands and inputs");
+    for (const auto &[operand, input] : llvm::zip_equal(
+             operandsToOpOperands(operands), dst.getSuccessorInputs()))
+      mapping[&operand].push_back(input);
+  }
+}
+
+void RegionBranchOpInterface::getSuccessorOperandInputMapping(
+    RegionBranchSuccessorMapping &mapping,
+    std::optional<RegionBranchPoint> src) {
+  if (src.has_value()) {
+    ::getSuccessorOperandInputMapping(*this, mapping, src.value());
+  } else {
+    // No region branch point specified: populate the mapping for all possible
+    // region branch points.
+    for (RegionBranchPoint branchPoint : getAllRegionBranchPoints())
+      ::getSuccessorOperandInputMapping(*this, mapping, branchPoint);
+  }
+}
+
+SmallVector<RegionBranchPoint>
+RegionBranchOpInterface::getAllRegionBranchPoints() {
+  SmallVector<RegionBranchPoint> branchPoints;
+  branchPoints.push_back(RegionBranchPoint::parent());
+  for (Region &region : getOperation()->getRegions()) {
+    for (Block &block : region) {
+      // An empty block has no terminator, and `Block::back()` must not be
+      // called on it.
+      if (block.empty())
+        continue;
+      if (auto terminator =
+              dyn_cast<RegionBranchTerminatorOpInterface>(block.back()))
+        branchPoints.push_back(RegionBranchPoint(terminator));
+    }
+  }
+  return branchPoints;
+}
+
 Region *mlir::getEnclosingRepetitiveRegion(Operation *op) {
   LDBG() << "Finding enclosing repetitive region for operation "
          << op->getName();


        


More information about the Mlir-commits mailing list