[Mlir-commits] [mlir] b26bb30 - Revert "[mlir] Use a type for representing branch points in `RegionBranchOpInterface`"

Markus Böck llvmlistbot at llvm.org
Tue Aug 29 11:46:12 PDT 2023


Author: Markus Böck
Date: 2023-08-29T20:17:50+02:00
New Revision: b26bb30b467b996c9786e3bd426c07684d84d406

URL: https://github.com/llvm/llvm-project/commit/b26bb30b467b996c9786e3bd426c07684d84d406
DIFF: https://github.com/llvm/llvm-project/commit/b26bb30b467b996c9786e3bd426c07684d84d406.diff

LOG: Revert "[mlir] Use a type for representing branch points in `RegionBranchOpInterface`"

This reverts commit 024f562da67180b7be1663048c960b26c2cc16f8.

Forgot to update flang

Added: 
    

Modified: 
    mlir/include/mlir/Analysis/DataFlow/DenseAnalysis.h
    mlir/include/mlir/Analysis/DataFlow/SparseAnalysis.h
    mlir/include/mlir/Interfaces/ControlFlowInterfaces.h
    mlir/include/mlir/Interfaces/ControlFlowInterfaces.td
    mlir/lib/Analysis/AliasAnalysis/LocalAliasAnalysis.cpp
    mlir/lib/Analysis/DataFlow/DenseAnalysis.cpp
    mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp
    mlir/lib/Dialect/Affine/IR/AffineOps.cpp
    mlir/lib/Dialect/Async/IR/Async.cpp
    mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocation.cpp
    mlir/lib/Dialect/Bufferization/Transforms/BufferUtils.cpp
    mlir/lib/Dialect/Bufferization/Transforms/BufferViewFlowAnalysis.cpp
    mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
    mlir/lib/Dialect/SCF/IR/SCF.cpp
    mlir/lib/Dialect/Shape/IR/Shape.cpp
    mlir/lib/Dialect/Transform/IR/TransformOps.cpp
    mlir/lib/Dialect/Vector/IR/VectorOps.cpp
    mlir/lib/Interfaces/ControlFlowInterfaces.cpp
    mlir/lib/Transforms/RemoveDeadValues.cpp
    mlir/test/lib/Analysis/DataFlow/TestDenseBackwardDataFlowAnalysis.cpp
    mlir/test/lib/Dialect/Test/TestDialect.cpp
    mlir/unittests/Interfaces/ControlFlowInterfacesTest.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Analysis/DataFlow/DenseAnalysis.h b/mlir/include/mlir/Analysis/DataFlow/DenseAnalysis.h
index 6a1335bab8bf6e..a3a558f7705074 100644
--- a/mlir/include/mlir/Analysis/DataFlow/DenseAnalysis.h
+++ b/mlir/include/mlir/Analysis/DataFlow/DenseAnalysis.h
@@ -353,8 +353,8 @@ class AbstractDenseBackwardDataFlowAnalysis : public DataFlowAnalysis {
   /// any effect on the lattice that isn't already expressed by the interface
   /// itself.
   virtual void visitRegionBranchControlFlowTransfer(
-      RegionBranchOpInterface branch, RegionBranchPoint regionFrom,
-      RegionBranchPoint regionTo, const AbstractDenseLattice &after,
+      RegionBranchOpInterface branch, std::optional<unsigned> regionFrom,
+      std::optional<unsigned> regionTo, const AbstractDenseLattice &after,
       AbstractDenseLattice *before) {
     meet(before, after);
   }
@@ -382,7 +382,7 @@ class AbstractDenseBackwardDataFlowAnalysis : public DataFlowAnalysis {
   /// of the branch operation itself.
   void visitRegionBranchOperation(ProgramPoint point,
                                   RegionBranchOpInterface branch,
-                                  RegionBranchPoint branchPoint,
+                                  std::optional<unsigned> regionNo,
                                   AbstractDenseLattice *before);
 
   /// Visit an operation for which the data flow is described by the
@@ -472,8 +472,9 @@ class DenseBackwardDataFlowAnalysis
   /// nullptr`. The behavior can be further refined for specific pairs of "from"
   /// and "to" regions.
   virtual void visitRegionBranchControlFlowTransfer(
-      RegionBranchOpInterface branch, RegionBranchPoint regionFrom,
-      RegionBranchPoint regionTo, const LatticeT &after, LatticeT *before) {
+      RegionBranchOpInterface branch, std::optional<unsigned> regionFrom,
+      std::optional<unsigned> regionTo, const LatticeT &after,
+      LatticeT *before) {
     AbstractDenseBackwardDataFlowAnalysis::visitRegionBranchControlFlowTransfer(
         branch, regionFrom, regionTo, after, before);
   }
@@ -507,8 +508,8 @@ class DenseBackwardDataFlowAnalysis
                                  static_cast<LatticeT *>(before));
   }
   void visitRegionBranchControlFlowTransfer(
-      RegionBranchOpInterface branch, RegionBranchPoint regionForm,
-      RegionBranchPoint regionTo, const AbstractDenseLattice &after,
+      RegionBranchOpInterface branch, std::optional<unsigned> regionForm,
+      std::optional<unsigned> regionTo, const AbstractDenseLattice &after,
       AbstractDenseLattice *before) final {
     visitRegionBranchControlFlowTransfer(branch, regionForm, regionTo,
                                          static_cast<const LatticeT &>(after),

diff  --git a/mlir/include/mlir/Analysis/DataFlow/SparseAnalysis.h b/mlir/include/mlir/Analysis/DataFlow/SparseAnalysis.h
index 5a9a36159b56c5..13dacff3aa0422 100644
--- a/mlir/include/mlir/Analysis/DataFlow/SparseAnalysis.h
+++ b/mlir/include/mlir/Analysis/DataFlow/SparseAnalysis.h
@@ -243,7 +243,7 @@ class AbstractSparseForwardDataFlowAnalysis : public DataFlowAnalysis {
   /// regions or the parent operation itself, and set either the argument or
   /// parent result lattices.
   void visitRegionSuccessors(ProgramPoint point, RegionBranchOpInterface branch,
-                             RegionBranchPoint successor,
+                             std::optional<unsigned> successorIndex,
                              ArrayRef<AbstractSparseLattice *> lattices);
 };
 

diff  --git a/mlir/include/mlir/Interfaces/ControlFlowInterfaces.h b/mlir/include/mlir/Interfaces/ControlFlowInterfaces.h
index 006aedced839f9..bd81da41ed43cd 100644
--- a/mlir/include/mlir/Interfaces/ControlFlowInterfaces.h
+++ b/mlir/include/mlir/Interfaces/ControlFlowInterfaces.h
@@ -190,68 +190,6 @@ class RegionSuccessor {
   ValueRange inputs;
 };
 
-/// This class represents a point being branched from in the methods of the
-/// `RegionBranchOpInterface`.
-/// One can branch from one of two kinds of places:
-/// * The parent operation (aka the `RegionBranchOpInterface` implementation)
-/// * A region within the parent operation.
-class RegionBranchPoint {
-public:
-  /// Returns an instance of `RegionBranchPoint` representing the parent
-  /// operation.
-  static constexpr RegionBranchPoint parent() { return RegionBranchPoint(); }
-
-  /// Creates a `RegionBranchPoint` that branches from the given region.
-  /// The pointer must not be null.
-  RegionBranchPoint(Region *region) : maybeRegion(region) {
-    assert(region && "Region must not be null");
-  }
-
-  RegionBranchPoint(Region &region) : RegionBranchPoint(&region) {}
-
-  /// Explicitly stops users from constructing with `nullptr`.
-  RegionBranchPoint(std::nullptr_t) = delete;
-
-  /// Constructs a `RegionBranchPoint` from the the target of a
-  /// `RegionSuccessor` instance.
-  RegionBranchPoint(RegionSuccessor successor) {
-    if (successor.isParent())
-      maybeRegion = nullptr;
-    else
-      maybeRegion = successor.getSuccessor();
-  }
-
-  /// Assigns a region being branched from.
-  RegionBranchPoint &operator=(Region &region) {
-    maybeRegion = &region;
-    return *this;
-  }
-
-  /// Returns true if branching from the parent op.
-  bool isParent() const { return maybeRegion == nullptr; }
-
-  /// Returns the region if branching from a region.
-  /// A null pointer otherwise.
-  Region *getRegionOrNull() const { return maybeRegion; }
-
-  /// Returns true if the two branch points are equal.
-  friend bool operator==(RegionBranchPoint lhs, RegionBranchPoint rhs) {
-    return lhs.maybeRegion == rhs.maybeRegion;
-  }
-
-private:
-  // Private constructor to encourage the use of `RegionBranchPoint::parent`.
-  constexpr RegionBranchPoint() : maybeRegion(nullptr) {}
-
-  /// Internal encoding. Uses nullptr for representing branching from the parent
-  /// op and the region being branched from otherwise.
-  Region *maybeRegion;
-};
-
-inline bool operator!=(RegionBranchPoint lhs, RegionBranchPoint rhs) {
-  return !(lhs == rhs);
-}
-
 /// This class represents upper and lower bounds on the number of times a region
 /// of a `RegionBranchOpInterface` can be invoked. The lower bound is at least
 /// zero, but the upper bound may not be known.

diff  --git a/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td b/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td
index e52636a5ac8fcc..132bd6d53d923a 100644
--- a/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td
+++ b/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td
@@ -133,14 +133,14 @@ def RegionBranchOpInterface : OpInterface<"RegionBranchOpInterface"> {
   let methods = [
     InterfaceMethod<[{
         Returns the operands of this operation used as the entry arguments when
-        branching from `point`, which was specified as a successor of
+        entering the region at `index`, which was specified as a successor of
         this operation by `getEntrySuccessorRegions`, or the operands forwarded
         to the operation's results when it branches back to itself. These operands
         should correspond 1-1 with the successor inputs specified in
         `getEntrySuccessorRegions`.
       }],
       "::mlir::OperandRange", "getEntrySuccessorOperands",
-      (ins "::mlir::RegionBranchPoint":$point), [{}],
+      (ins "::std::optional<unsigned>":$index), [{}],
       /*defaultImplementation=*/[{
         auto operandEnd = this->getOperation()->operand_end();
         return ::mlir::OperandRange(operandEnd, operandEnd);
@@ -162,20 +162,22 @@ def RegionBranchOpInterface : OpInterface<"RegionBranchOpInterface"> {
       (ins "::llvm::ArrayRef<::mlir::Attribute>":$operands,
            "::llvm::SmallVectorImpl<::mlir::RegionSuccessor> &":$regions),
            [{}], [{
-        $_op.getSuccessorRegions(mlir::RegionBranchPoint::parent(), regions);
+        $_op.getSuccessorRegions(std::nullopt, regions);
       }]
     >,
     InterfaceMethod<[{
-        Returns the viable successors of `point`. These are the regions that may
-        be selected during the flow of control. The parent operation, may
-        specify itself as successor, which indicates that the control flow may
-        not enter any region at all. This method allows for describing which
-        regions may be executed when entering an operation, and which regions
-        are executed after having executed another region of the parent op. The
-        successor region must be non-empty.
+        Returns the viable successors of a region at `index`, or the possible
+        successors when branching from the parent op if `index` is None. These
+        are the regions that may be selected during the flow of control. The
+        parent operation, i.e. a null `index`, may specify itself as successor,
+        which indicates that the control flow may not enter any region at all.
+        This method allows for describing which regions may be executed when
+        entering an operation, and which regions are executed after having
+        executed another region of the parent op. The successor region must be
+        non-empty.
       }],
       "void", "getSuccessorRegions",
-      (ins "::mlir::RegionBranchPoint":$point,
+      (ins "::std::optional<unsigned>":$index,
            "::llvm::SmallVectorImpl<::mlir::RegionSuccessor> &":$regions)
     >,
     InterfaceMethod<[{
@@ -243,10 +245,12 @@ def RegionBranchTerminatorOpInterface :
   let methods = [
     InterfaceMethod<[{
         Returns a mutable range of operands that are semantically "returned" by
-        passing them to the region successor given by `point`.
+        passing them to the region successor given by `index`.  If `index` is
+        None, this function returns the operands that are passed as a result to
+        the parent operation.
       }],
       "::mlir::MutableOperandRange", "getMutableSuccessorOperands",
-      (ins "::mlir::RegionBranchPoint":$point)
+      (ins "::std::optional<unsigned>":$index)
     >,
     InterfaceMethod<[{
         Returns the viable region successors that are branched to after this
@@ -265,7 +269,8 @@ def RegionBranchTerminatorOpInterface :
       [{
         ::mlir::Operation *op = $_op;
         ::llvm::cast<::mlir::RegionBranchOpInterface>(op->getParentOp())
-          .getSuccessorRegions(op->getParentRegion(), regions);
+          .getSuccessorRegions(op->getParentRegion()->getRegionNumber(),
+            regions);
       }]
     >,
   ];
@@ -285,8 +290,8 @@ def RegionBranchTerminatorOpInterface :
     // them to the region successor given by `index`.  If `index` is None, this
     // function returns the operands that are passed as a result to the parent
     // operation.
-    ::mlir::OperandRange getSuccessorOperands(::mlir::RegionBranchPoint point) {
-      return getMutableSuccessorOperands(point);
+    ::mlir::OperandRange getSuccessorOperands(std::optional<unsigned> index) {
+      return getMutableSuccessorOperands(index);
     }
   }];
 }
@@ -304,7 +309,7 @@ def ReturnLike : TraitList<[
         /*extraOpDeclaration=*/"",
         /*extraOpDefinition=*/[{
           ::mlir::MutableOperandRange $cppClass::getMutableSuccessorOperands(
-            ::mlir::RegionBranchPoint point) {
+            ::std::optional<unsigned> index) {
             return ::mlir::MutableOperandRange(*this);
           }
         }]

diff  --git a/mlir/lib/Analysis/AliasAnalysis/LocalAliasAnalysis.cpp b/mlir/lib/Analysis/AliasAnalysis/LocalAliasAnalysis.cpp
index ae2ba90412137c..970e68bc258649 100644
--- a/mlir/lib/Analysis/AliasAnalysis/LocalAliasAnalysis.cpp
+++ b/mlir/lib/Analysis/AliasAnalysis/LocalAliasAnalysis.cpp
@@ -45,9 +45,9 @@ static void collectUnderlyingAddressValues(RegionBranchOpInterface branch,
   // this region predecessor that correspond to the input values of `region`. If
   // an index could not be found, std::nullopt is returned instead.
   auto getOperandIndexIfPred =
-      [&](RegionBranchPoint pred) -> std::optional<unsigned> {
+      [&](std::optional<unsigned> predIndex) -> std::optional<unsigned> {
     SmallVector<RegionSuccessor, 2> successors;
-    branch.getSuccessorRegions(pred, successors);
+    branch.getSuccessorRegions(predIndex, successors);
     for (RegionSuccessor &successor : successors) {
       if (successor.getSuccessor() != region)
         continue;
@@ -75,27 +75,28 @@ static void collectUnderlyingAddressValues(RegionBranchOpInterface branch,
   };
 
   // Check branches from the parent operation.
-  auto branchPoint = RegionBranchPoint::parent();
-  if (region)
-    branchPoint = region;
-
+  std::optional<unsigned> regionIndex;
+  if (region) {
+    // Determine the actual region number from the passed region.
+    regionIndex = region->getRegionNumber();
+  }
   if (std::optional<unsigned> operandIndex =
-          getOperandIndexIfPred(/*predIndex=*/RegionBranchPoint::parent())) {
+          getOperandIndexIfPred(/*predIndex=*/std::nullopt)) {
     collectUnderlyingAddressValues(
-        branch.getEntrySuccessorOperands(branchPoint)[*operandIndex], maxDepth,
+        branch.getEntrySuccessorOperands(regionIndex)[*operandIndex], maxDepth,
         visited, output);
   }
   // Check branches from each child region.
   Operation *op = branch.getOperation();
-  for (Region &region : op->getRegions()) {
-    if (std::optional<unsigned> operandIndex = getOperandIndexIfPred(region)) {
-      for (Block &block : region) {
+  for (int i = 0, e = op->getNumRegions(); i != e; ++i) {
+    if (std::optional<unsigned> operandIndex = getOperandIndexIfPred(i)) {
+      for (Block &block : op->getRegion(i)) {
         // Try to determine possible region-branch successor operands for the
         // current region.
         if (auto term = dyn_cast<RegionBranchTerminatorOpInterface>(
                 block.getTerminator())) {
           collectUnderlyingAddressValues(
-              term.getSuccessorOperands(branchPoint)[*operandIndex], maxDepth,
+              term.getSuccessorOperands(regionIndex)[*operandIndex], maxDepth,
               visited, output);
         } else if (block.getNumSuccessors()) {
           // Otherwise, if this terminator may exit the region we can't make

diff  --git a/mlir/lib/Analysis/DataFlow/DenseAnalysis.cpp b/mlir/lib/Analysis/DataFlow/DenseAnalysis.cpp
index eab408cd5977c3..c79a360e4c11bb 100644
--- a/mlir/lib/Analysis/DataFlow/DenseAnalysis.cpp
+++ b/mlir/lib/Analysis/DataFlow/DenseAnalysis.cpp
@@ -312,8 +312,7 @@ void AbstractDenseBackwardDataFlowAnalysis::processOperation(Operation *op) {
 
   // Special cases where control flow may dictate data flow.
   if (auto branch = dyn_cast<RegionBranchOpInterface>(op))
-    return visitRegionBranchOperation(op, branch, RegionBranchPoint::parent(),
-                                      before);
+    return visitRegionBranchOperation(op, branch, std::nullopt, before);
   if (auto call = dyn_cast<CallOpInterface>(op))
     return visitCallOperation(call, before);
 
@@ -369,7 +368,8 @@ void AbstractDenseBackwardDataFlowAnalysis::visitBlock(Block *block) {
     // If this block is exiting from an operation with region-based control
     // flow, propagate the lattice back along the control flow edge.
     if (auto branch = dyn_cast<RegionBranchOpInterface>(block->getParentOp())) {
-      visitRegionBranchOperation(block, branch, block->getParent(), before);
+      visitRegionBranchOperation(block, branch,
+                                 block->getParent()->getRegionNumber(), before);
       return;
     }
 
@@ -396,13 +396,13 @@ void AbstractDenseBackwardDataFlowAnalysis::visitBlock(Block *block) {
 
 void AbstractDenseBackwardDataFlowAnalysis::visitRegionBranchOperation(
     ProgramPoint point, RegionBranchOpInterface branch,
-    RegionBranchPoint branchPoint, AbstractDenseLattice *before) {
+    std::optional<unsigned> regionNo, AbstractDenseLattice *before) {
 
   // The successors of the operation may be either the first operation of the
   // entry block of each possible successor region, or the next operation when
   // the branch is a successor of itself.
   SmallVector<RegionSuccessor> successors;
-  branch.getSuccessorRegions(branchPoint, successors);
+  branch.getSuccessorRegions(regionNo, successors);
   for (const RegionSuccessor &successor : successors) {
     const AbstractDenseLattice *after;
     if (successor.isParent() || successor.getSuccessor()->empty()) {
@@ -423,8 +423,10 @@ void AbstractDenseBackwardDataFlowAnalysis::visitRegionBranchOperation(
       else
         after = getLatticeFor(point, &successorBlock->front());
     }
-
-    visitRegionBranchControlFlowTransfer(branch, branchPoint, successor, *after,
+    std::optional<unsigned> successorNo =
+        successor.isParent() ? std::optional<unsigned>()
+                             : successor.getSuccessor()->getRegionNumber();
+    visitRegionBranchControlFlowTransfer(branch, regionNo, successorNo, *after,
                                          before);
   }
 }

diff  --git a/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp b/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp
index 02a0ce1bb29213..4708cdb042f126 100644
--- a/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp
+++ b/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp
@@ -99,7 +99,7 @@ void AbstractSparseForwardDataFlowAnalysis::visitOperation(Operation *op) {
   // The results of a region branch operation are determined by control-flow.
   if (auto branch = dyn_cast<RegionBranchOpInterface>(op)) {
     return visitRegionSuccessors({branch}, branch,
-                                 /*successor=*/RegionBranchPoint::parent(),
+                                 /*successorIndex=*/std::nullopt,
                                  resultLattices);
   }
 
@@ -167,8 +167,8 @@ void AbstractSparseForwardDataFlowAnalysis::visitBlock(Block *block) {
 
     // Check if the lattices can be determined from region control flow.
     if (auto branch = dyn_cast<RegionBranchOpInterface>(block->getParentOp())) {
-      return visitRegionSuccessors(block, branch, block->getParent(),
-                                   argLattices);
+      return visitRegionSuccessors(
+          block, branch, block->getParent()->getRegionNumber(), argLattices);
     }
 
     // Otherwise, we can't reason about the data-flow.
@@ -212,7 +212,8 @@ void AbstractSparseForwardDataFlowAnalysis::visitBlock(Block *block) {
 
 void AbstractSparseForwardDataFlowAnalysis::visitRegionSuccessors(
     ProgramPoint point, RegionBranchOpInterface branch,
-    RegionBranchPoint successor, ArrayRef<AbstractSparseLattice *> lattices) {
+    std::optional<unsigned> successorIndex,
+    ArrayRef<AbstractSparseLattice *> lattices) {
   const auto *predecessors = getOrCreateFor<PredecessorState>(point, point);
   assert(predecessors->allPredecessorsKnown() &&
          "unexpected unresolved region successors");
@@ -223,11 +224,11 @@ void AbstractSparseForwardDataFlowAnalysis::visitRegionSuccessors(
 
     // Check if the predecessor is the parent op.
     if (op == branch) {
-      operands = branch.getEntrySuccessorOperands(successor);
+      operands = branch.getEntrySuccessorOperands(successorIndex);
       // Otherwise, try to deduce the operands from a region return-like op.
     } else if (auto regionTerminator =
                    dyn_cast<RegionBranchTerminatorOpInterface>(op)) {
-      operands = regionTerminator.getSuccessorOperands(successor);
+      operands = regionTerminator.getSuccessorOperands(successorIndex);
     }
 
     if (!operands) {
@@ -500,7 +501,10 @@ void AbstractSparseBackwardDataFlowAnalysis::visitRegionSuccessors(
   BitVector unaccounted(op->getNumOperands(), true);
 
   for (RegionSuccessor &successor : successors) {
-    OperandRange operands = branch.getEntrySuccessorOperands(successor);
+    Region *region = successor.getSuccessor();
+    OperandRange operands =
+        region ? branch.getEntrySuccessorOperands(region->getRegionNumber())
+               : branch.getEntrySuccessorOperands({});
     MutableArrayRef<OpOperand> opoperands = operandsToOpOperands(operands);
     ValueRange inputs = successor.getSuccessorInputs();
     for (auto [operand, input] : llvm::zip(opoperands, inputs)) {
@@ -534,7 +538,9 @@ void AbstractSparseBackwardDataFlowAnalysis::
 
   for (const RegionSuccessor &successor : successors) {
     ValueRange inputs = successor.getSuccessorInputs();
-    OperandRange operands = terminator.getSuccessorOperands(successor);
+    Region *region = successor.getSuccessor();
+    OperandRange operands = terminator.getSuccessorOperands(
+        region ? region->getRegionNumber() : std::optional<unsigned>{});
     MutableArrayRef<OpOperand> opOperands = operandsToOpOperands(operands);
     for (auto [opOperand, input] : llvm::zip(opOperands, inputs)) {
       meet(getLatticeElement(opOperand.get()),

diff  --git a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
index 9d7b8f371a26c6..bb4aaee21d019e 100644
--- a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
+++ b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
@@ -2379,9 +2379,9 @@ void AffineForOp::getCanonicalizationPatterns(RewritePatternSet &results,
 /// correspond to the loop iterator operands, i.e., those excluding the
 /// induction variable. AffineForOp only has one region, so zero is the only
 /// valid value for `index`.
-OperandRange AffineForOp::getEntrySuccessorOperands(RegionBranchPoint point) {
-  assert((point.isParent() || point == getLoopBody()) &&
-         "invalid region point");
+OperandRange
+AffineForOp::getEntrySuccessorOperands(std::optional<unsigned> index) {
+  assert((!index || *index == 0) && "invalid region index");
 
   // The initial operands map to the loop arguments after the induction
   // variable or are forwarded to the results when the trip count is zero.
@@ -2394,15 +2394,14 @@ OperandRange AffineForOp::getEntrySuccessorOperands(RegionBranchPoint point) {
 /// correspond to a constant value for each operand, or null if that operand is
 /// not a constant.
 void AffineForOp::getSuccessorRegions(
-    RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
-  assert((point.isParent() || point == getLoopBody()) &&
-         "expected loop region");
+    std::optional<unsigned> index, SmallVectorImpl<RegionSuccessor> &regions) {
+  assert((!index.has_value() || index.value() == 0) && "expected loop region");
   // The loop may typically branch back to its body or to the parent operation.
   // If the predecessor is the parent op and the trip count is known to be at
   // least one, branch into the body using the iterator arguments. And in cases
   // we know the trip count is zero, it can only branch back to its parent.
   std::optional<uint64_t> tripCount = getTrivialConstantTripCount(*this);
-  if (point.isParent() && tripCount.has_value()) {
+  if (!index.has_value() && tripCount.has_value()) {
     if (tripCount.value() > 0) {
       regions.push_back(RegionSuccessor(&getLoopBody(), getRegionIterArgs()));
       return;
@@ -2415,7 +2414,7 @@ void AffineForOp::getSuccessorRegions(
 
   // From the loop body, if the trip count is one, we can only branch back to
   // the parent.
-  if (!point.isParent() && tripCount && *tripCount == 1) {
+  if (index && tripCount && *tripCount == 1) {
     regions.push_back(RegionSuccessor(getResults()));
     return;
   }
@@ -2860,10 +2859,10 @@ struct AlwaysTrueOrFalseIf : public OpRewritePattern<AffineIfOp> {
 /// AffineIfOp has two regions -- `then` and `else`. The flow of data should be
 /// as follows: AffineIfOp -> `then`/`else` -> AffineIfOp
 void AffineIfOp::getSuccessorRegions(
-    RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
+    std::optional<unsigned> index, SmallVectorImpl<RegionSuccessor> &regions) {
   // If the predecessor is an AffineIfOp, then branching into both `then` and
   // `else` region is valid.
-  if (point.isParent()) {
+  if (!index.has_value()) {
     regions.reserve(2);
     regions.push_back(
         RegionSuccessor(&getThenRegion(), getThenRegion().getArguments()));

diff  --git a/mlir/lib/Dialect/Async/IR/Async.cpp b/mlir/lib/Dialect/Async/IR/Async.cpp
index a05e02faf6d2f0..9b4fb81990c169 100644
--- a/mlir/lib/Dialect/Async/IR/Async.cpp
+++ b/mlir/lib/Dialect/Async/IR/Async.cpp
@@ -38,8 +38,9 @@ void AsyncDialect::initialize() {
 
 constexpr char kOperandSegmentSizesAttr[] = "operandSegmentSizes";
 
-OperandRange ExecuteOp::getEntrySuccessorOperands(RegionBranchPoint point) {
-  assert(point == getBodyRegion() && "invalid region index");
+OperandRange
+ExecuteOp::getEntrySuccessorOperands(std::optional<unsigned> index) {
+  assert(index && *index == 0 && "invalid region index");
   return getBodyOperands();
 }
 
@@ -52,10 +53,11 @@ bool ExecuteOp::areTypesCompatible(Type lhs, Type rhs) {
   return getValueOrTokenType(lhs) == getValueOrTokenType(rhs);
 }
 
-void ExecuteOp::getSuccessorRegions(RegionBranchPoint point,
+void ExecuteOp::getSuccessorRegions(std::optional<unsigned> index,
                                     SmallVectorImpl<RegionSuccessor> &regions) {
   // The `body` region branch back to the parent operation.
-  if (point == getBodyRegion()) {
+  if (index) {
+    assert(*index == 0 && "invalid region index");
     regions.push_back(RegionSuccessor(getBodyResults()));
     return;
   }

diff  --git a/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocation.cpp b/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocation.cpp
index 9a831e4c322ea0..582974873263d2 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocation.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocation.cpp
@@ -372,7 +372,7 @@ class BufferDeallocation : public BufferPlacementTransformationBase {
     // parent operation. In this case, we have to introduce an additional clone
     // for buffer that is passed to the argument.
     SmallVector<RegionSuccessor, 2> successorRegions;
-    regionInterface.getSuccessorRegions(/*point=*/RegionBranchPoint::parent(),
+    regionInterface.getSuccessorRegions(/*index=*/std::nullopt,
                                         successorRegions);
     auto *it =
         llvm::find_if(successorRegions, [&](RegionSuccessor &successorRegion) {
@@ -383,7 +383,8 @@ class BufferDeallocation : public BufferPlacementTransformationBase {
 
     // Determine the actual operand to introduce a clone for and rewire the
     // operand to point to the clone instead.
-    auto operands = regionInterface.getEntrySuccessorOperands(argRegion);
+    auto operands =
+        regionInterface.getEntrySuccessorOperands(argRegion->getRegionNumber());
     size_t operandIndex =
         llvm::find(it->getSuccessorInputs(), blockArg).getIndex() +
         operands.getBeginOperandIndex();
@@ -431,7 +432,8 @@ class BufferDeallocation : public BufferPlacementTransformationBase {
       // Query the regionInterface to get all successor regions of the current
       // one.
       SmallVector<RegionSuccessor, 2> successorRegions;
-      regionInterface.getSuccessorRegions(region, successorRegions);
+      regionInterface.getSuccessorRegions(region.getRegionNumber(),
+                                          successorRegions);
       // Try to find a matching region successor.
       RegionSuccessor *regionSuccessor =
           llvm::find_if(successorRegions, regionPredicate);
@@ -443,6 +445,10 @@ class BufferDeallocation : public BufferPlacementTransformationBase {
           llvm::find(regionSuccessor->getSuccessorInputs(), argValue)
               .getIndex();
 
+      std::optional<unsigned> successorRegionNumber;
+      if (Region *successorRegion = regionSuccessor->getSuccessor())
+        successorRegionNumber = successorRegion->getRegionNumber();
+
       // Iterate over all immediate terminator operations to introduce
       // new buffer allocations. Thereby, the appropriate terminator operand
       // will be adjusted to point to the newly allocated buffer instead.
@@ -450,7 +456,8 @@ class BufferDeallocation : public BufferPlacementTransformationBase {
               &region, [&](RegionBranchTerminatorOpInterface terminator) {
                 // Get the actual mutable operands for this terminator op.
                 auto terminatorOperands =
-                    terminator.getMutableSuccessorOperands(*regionSuccessor);
+                    terminator.getMutableSuccessorOperands(
+                        successorRegionNumber);
                 // Extract the source value from the current terminator.
                 // This conversion needs to exist on a separate line due to a
                 // bug in GCC conversion analysis.

diff  --git a/mlir/lib/Dialect/Bufferization/Transforms/BufferUtils.cpp b/mlir/lib/Dialect/Bufferization/Transforms/BufferUtils.cpp
index 119801f9cc92f3..f8231cac778af6 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/BufferUtils.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/BufferUtils.cpp
@@ -123,7 +123,7 @@ bool BufferPlacementTransformationBase::isLoop(Operation *op) {
       return true;
     // Recurses into all region successors.
     SmallVector<RegionSuccessor, 2> successors;
-    regionInterface.getSuccessorRegions(current, successors);
+    regionInterface.getSuccessorRegions(current->getRegionNumber(), successors);
     for (RegionSuccessor &regionEntry : successors)
       if (recurse(regionEntry.getSuccessor()))
         return true;
@@ -132,8 +132,7 @@ bool BufferPlacementTransformationBase::isLoop(Operation *op) {
 
   // Start with all entry regions and test whether they induce a loop.
   SmallVector<RegionSuccessor, 2> successorRegions;
-  regionInterface.getSuccessorRegions(/*point=*/RegionBranchPoint::parent(),
-                                      successorRegions);
+  regionInterface.getSuccessorRegions(/*index=*/std::nullopt, successorRegions);
   for (RegionSuccessor &regionEntry : successorRegions) {
     if (recurse(regionEntry.getSuccessor()))
       return true;

diff  --git a/mlir/lib/Dialect/Bufferization/Transforms/BufferViewFlowAnalysis.cpp b/mlir/lib/Dialect/Bufferization/Transforms/BufferViewFlowAnalysis.cpp
index 98a60a48763ab1..d201e024380661 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/BufferViewFlowAnalysis.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/BufferViewFlowAnalysis.cpp
@@ -100,13 +100,16 @@ void BufferViewFlowAnalysis::build(Operation *op) {
       // Query the RegionBranchOpInterface to find potential successor regions.
       // Extract all entry regions and wire all initial entry successor inputs.
       SmallVector<RegionSuccessor, 2> entrySuccessors;
-      regionInterface.getSuccessorRegions(/*point=*/RegionBranchPoint::parent(),
+      regionInterface.getSuccessorRegions(/*index=*/std::nullopt,
                                           entrySuccessors);
       for (RegionSuccessor &entrySuccessor : entrySuccessors) {
         // Wire the entry region's successor arguments with the initial
         // successor inputs.
         registerDependencies(
-            regionInterface.getEntrySuccessorOperands(entrySuccessor),
+            regionInterface.getEntrySuccessorOperands(
+                entrySuccessor.isParent()
+                    ? std::optional<unsigned>()
+                    : entrySuccessor.getSuccessor()->getRegionNumber()),
             entrySuccessor.getSuccessorInputs());
       }
 
@@ -115,16 +118,21 @@ void BufferViewFlowAnalysis::build(Operation *op) {
         // Iterate over all successor region entries that are reachable from the
         // current region.
         SmallVector<RegionSuccessor, 2> successorRegions;
-        regionInterface.getSuccessorRegions(region, successorRegions);
+        regionInterface.getSuccessorRegions(region.getRegionNumber(),
+                                            successorRegions);
         for (RegionSuccessor &successorRegion : successorRegions) {
+          // Determine the current region index (if any).
+          std::optional<unsigned> regionIndex;
+          Region *regionSuccessor = successorRegion.getSuccessor();
+          if (regionSuccessor)
+            regionIndex = regionSuccessor->getRegionNumber();
           // Iterate over all immediate terminator operations and wire the
           // successor inputs with the successor operands of each terminator.
           for (Block &block : region)
             if (auto terminator = dyn_cast<RegionBranchTerminatorOpInterface>(
                     block.getTerminator()))
-              registerDependencies(
-                  terminator.getSuccessorOperands(successorRegion),
-                  successorRegion.getSuccessorInputs());
+              registerDependencies(terminator.getSuccessorOperands(regionIndex),
+                                   successorRegion.getSuccessorInputs());
         }
       }
 

diff  --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index 9c5c322e23692b..e1b8dd62450a77 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -455,8 +455,8 @@ ParseResult AllocaScopeOp::parse(OpAsmParser &parser, OperationState &result) {
 }
 
 void AllocaScopeOp::getSuccessorRegions(
-    RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
-  if (!point.isParent()) {
+    std::optional<unsigned> index, SmallVectorImpl<RegionSuccessor> &regions) {
+  if (index) {
     regions.push_back(RegionSuccessor(getResults()));
     return;
   }

diff  --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
index b573291f0460e6..63ce3b2a469627 100644
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -266,9 +266,9 @@ void ExecuteRegionOp::getCanonicalizationPatterns(RewritePatternSet &results,
 /// correspond to a constant value for each operand, or null if that operand is
 /// not a constant.
 void ExecuteRegionOp::getSuccessorRegions(
-    RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
+    std::optional<unsigned> index, SmallVectorImpl<RegionSuccessor> &regions) {
   // If the predecessor is the ExecuteRegionOp, branch into the body.
-  if (point.isParent()) {
+  if (!index) {
     regions.push_back(RegionSuccessor(&getRegion()));
     return;
   }
@@ -282,8 +282,8 @@ void ExecuteRegionOp::getSuccessorRegions(
 //===----------------------------------------------------------------------===//
 
 MutableOperandRange
-ConditionOp::getMutableSuccessorOperands(RegionBranchPoint point) {
-  assert((point.isParent() || point == getParentOp().getAfter()) &&
+ConditionOp::getMutableSuccessorOperands(std::optional<unsigned> index) {
+  assert((!index || index == getParentOp().getAfter().getRegionNumber()) &&
          "condition op can only exit the loop or branch to the after"
          "region");
   // Pass all operands except the condition to the successor region.
@@ -553,7 +553,7 @@ ForOp mlir::scf::getForInductionVarOwner(Value val) {
 /// Return operands used when entering the region at 'index'. These operands
 /// correspond to the loop iterator operands, i.e., those excluding the
 /// induction variable.
-OperandRange ForOp::getEntrySuccessorOperands(RegionBranchPoint point) {
+OperandRange ForOp::getEntrySuccessorOperands(std::optional<unsigned> index) {
   return getInitArgs();
 }
 
@@ -562,7 +562,7 @@ OperandRange ForOp::getEntrySuccessorOperands(RegionBranchPoint point) {
 /// during the flow of control. `operands` is a set of optional attributes that
 /// correspond to a constant value for each operand, or null if that operand is
 /// not a constant.
-void ForOp::getSuccessorRegions(RegionBranchPoint point,
+void ForOp::getSuccessorRegions(std::optional<unsigned> index,
                                 SmallVectorImpl<RegionSuccessor> &regions) {
   // Both the operation itself and the region may be branching into the body or
   // back into the operation itself. It is possible for loop not to enter the
@@ -1731,7 +1731,7 @@ void ForallOp::getCanonicalizationPatterns(RewritePatternSet &results,
 /// during the flow of control. `operands` is a set of optional attributes that
 /// correspond to a constant value for each operand, or null if that operand is
 /// not a constant.
-void ForallOp::getSuccessorRegions(RegionBranchPoint point,
+void ForallOp::getSuccessorRegions(std::optional<unsigned> index,
                                    SmallVectorImpl<RegionSuccessor> &regions) {
   // Both the operation itself and the region may be branching into the body or
   // back into the operation itself. It is possible for loop not to enter the
@@ -2011,10 +2011,10 @@ void IfOp::print(OpAsmPrinter &p) {
 /// during the flow of control. `operands` is a set of optional attributes that
 /// correspond to a constant value for each operand, or null if that operand is
 /// not a constant.
-void IfOp::getSuccessorRegions(RegionBranchPoint point,
+void IfOp::getSuccessorRegions(std::optional<unsigned> index,
                                SmallVectorImpl<RegionSuccessor> &regions) {
   // The `then` and the `else` region branch back to the parent operation.
-  if (!point.isParent()) {
+  if (index) {
     regions.push_back(RegionSuccessor(getResults()));
     return;
   }
@@ -3042,7 +3042,7 @@ void ParallelOp::getCanonicalizationPatterns(RewritePatternSet &results,
 /// correspond to a constant value for each operand, or null if that operand is
 /// not a constant.
 void ParallelOp::getSuccessorRegions(
-    RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
+    std::optional<unsigned> index, SmallVectorImpl<RegionSuccessor> &regions) {
   // Both the operation itself and the region may be branching into the body or
   // back into the operation itself. It is possible for loop not to enter the
   // body.
@@ -3169,8 +3169,8 @@ void WhileOp::build(::mlir::OpBuilder &odsBuilder,
     afterBuilder(odsBuilder, odsState.location, afterBlock->getArguments());
 }
 
-OperandRange WhileOp::getEntrySuccessorOperands(RegionBranchPoint point) {
-  assert(point == getBefore() &&
+OperandRange WhileOp::getEntrySuccessorOperands(std::optional<unsigned> index) {
+  assert(index && *index == 0 &&
          "WhileOp is expected to branch only to the first region");
 
   return getInits();
@@ -3192,18 +3192,17 @@ Block::BlockArgListType WhileOp::getAfterArguments() {
   return getAfterBody()->getArguments();
 }
 
-void WhileOp::getSuccessorRegions(RegionBranchPoint point,
+void WhileOp::getSuccessorRegions(std::optional<unsigned> index,
                                   SmallVectorImpl<RegionSuccessor> &regions) {
   // The parent op always branches to the condition region.
-  if (point.isParent()) {
+  if (!index) {
     regions.emplace_back(&getBefore(), getBefore().getArguments());
     return;
   }
 
-  assert(llvm::is_contained({&getAfter(), &getBefore()}, point) &&
-         "there are only two regions in a WhileOp");
+  assert(*index < 2 && "there are only two regions in a WhileOp");
   // The body region always branches back to the condition region.
-  if (point == getAfter()) {
+  if (*index == 1) {
     regions.emplace_back(&getBefore(), getBefore().getArguments());
     return;
   }
@@ -4024,9 +4023,10 @@ Block &scf::IndexSwitchOp::getCaseBlock(unsigned idx) {
 }
 
 void IndexSwitchOp::getSuccessorRegions(
-    RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &successors) {
+    std::optional<unsigned> index,
+    SmallVectorImpl<RegionSuccessor> &successors) {
   // All regions branch back to the parent op.
-  if (!point.isParent()) {
+  if (index) {
     successors.emplace_back(getResults());
     return;
   }

diff  --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp
index 78b06d9ce033f8..c52a9c5004aaaf 100644
--- a/mlir/lib/Dialect/Shape/IR/Shape.cpp
+++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp
@@ -335,11 +335,11 @@ void AssumingOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
 
 // See RegionBranchOpInterface in Interfaces/ControlFlowInterfaces.td
 void AssumingOp::getSuccessorRegions(
-    RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
+    std::optional<unsigned> index, SmallVectorImpl<RegionSuccessor> &regions) {
   // AssumingOp has unconditional control flow into the region and back to the
   // parent, so return the correct RegionSuccessor purely based on the index
   // being None or 0.
-  if (!point.isParent()) {
+  if (index) {
     regions.push_back(RegionSuccessor(getResults()));
     return;
   }

diff  --git a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
index 518bfc3931e8d3..7bc7272b054129 100644
--- a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
@@ -86,25 +86,23 @@ ensurePayloadIsSeparateFromTransform(transform::TransformOpInterface transform,
 // AlternativesOp
 //===----------------------------------------------------------------------===//
 
-OperandRange
-transform::AlternativesOp::getEntrySuccessorOperands(RegionBranchPoint point) {
-  if (!point.isParent() && getOperation()->getNumOperands() == 1)
+OperandRange transform::AlternativesOp::getEntrySuccessorOperands(
+    std::optional<unsigned> index) {
+  if (index && getOperation()->getNumOperands() == 1)
     return getOperation()->getOperands();
   return OperandRange(getOperation()->operand_end(),
                       getOperation()->operand_end());
 }
 
 void transform::AlternativesOp::getSuccessorRegions(
-    RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
+    std::optional<unsigned> index, SmallVectorImpl<RegionSuccessor> &regions) {
   for (Region &alternative : llvm::drop_begin(
-           getAlternatives(),
-           point.isParent() ? 0
-                            : point.getRegionOrNull()->getRegionNumber() + 1)) {
+           getAlternatives(), index.has_value() ? *index + 1 : 0)) {
     regions.emplace_back(&alternative, !getOperands().empty()
                                            ? alternative.getArguments()
                                            : Block::BlockArgListType());
   }
-  if (!point.isParent())
+  if (index.has_value())
     regions.emplace_back(getOperation()->getResults());
 }
 
@@ -1161,24 +1159,24 @@ void transform::ForeachOp::getEffects(
 }
 
 void transform::ForeachOp::getSuccessorRegions(
-    RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
+    std::optional<unsigned> index, SmallVectorImpl<RegionSuccessor> &regions) {
   Region *bodyRegion = &getBody();
-  if (point.isParent()) {
+  if (!index) {
     regions.emplace_back(bodyRegion, bodyRegion->getArguments());
     return;
   }
 
   // Branch back to the region or the parent.
-  assert(point == getBody() && "unexpected region index");
+  assert(*index == 0 && "unexpected region index");
   regions.emplace_back(bodyRegion, bodyRegion->getArguments());
   regions.emplace_back();
 }
 
 OperandRange
-transform::ForeachOp::getEntrySuccessorOperands(RegionBranchPoint point) {
+transform::ForeachOp::getEntrySuccessorOperands(std::optional<unsigned> index) {
   // The iteration variable op handle is mapped to a subset (one op to be
   // precise) of the payload ops of the ForeachOp operand.
-  assert(point == getBody() && "unexpected region index");
+  assert(index && *index == 0 && "unexpected region index");
   return getOperation()->getOperands();
 }
 
@@ -2180,9 +2178,9 @@ void transform::SequenceOp::getEffects(
   getPotentialTopLevelEffects(effects);
 }
 
-OperandRange
-transform::SequenceOp::getEntrySuccessorOperands(RegionBranchPoint point) {
-  assert(point == getBody() && "unexpected region index");
+OperandRange transform::SequenceOp::getEntrySuccessorOperands(
+    std::optional<unsigned> index) {
+  assert(index && *index == 0 && "unexpected region index");
   if (getOperation()->getNumOperands() > 0)
     return getOperation()->getOperands();
   return OperandRange(getOperation()->operand_end(),
@@ -2190,8 +2188,8 @@ transform::SequenceOp::getEntrySuccessorOperands(RegionBranchPoint point) {
 }
 
 void transform::SequenceOp::getSuccessorRegions(
-    RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
-  if (point.isParent()) {
+    std::optional<unsigned> index, SmallVectorImpl<RegionSuccessor> &regions) {
+  if (!index) {
     Region *bodyRegion = &getBody();
     regions.emplace_back(bodyRegion, getNumOperands() != 0
                                          ? bodyRegion->getArguments()
@@ -2199,7 +2197,7 @@ void transform::SequenceOp::getSuccessorRegions(
     return;
   }
 
-  assert(point == getBody() && "unexpected region index");
+  assert(*index == 0 && "unexpected region index");
   regions.emplace_back(getOperation()->getResults());
 }
 

diff  --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 88bda3931a5a11..4e9364611b257d 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -5821,8 +5821,8 @@ ParseResult WarpExecuteOnLane0Op::parse(OpAsmParser &parser,
 }
 
 void WarpExecuteOnLane0Op::getSuccessorRegions(
-    RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
-  if (!point.isParent()) {
+    std::optional<unsigned> index, SmallVectorImpl<RegionSuccessor> &regions) {
+  if (index) {
     regions.push_back(RegionSuccessor(getResults()));
     return;
   }

diff  --git a/mlir/lib/Interfaces/ControlFlowInterfaces.cpp b/mlir/lib/Interfaces/ControlFlowInterfaces.cpp
index b3166155e84f93..cc90da370de693 100644
--- a/mlir/lib/Interfaces/ControlFlowInterfaces.cpp
+++ b/mlir/lib/Interfaces/ControlFlowInterfaces.cpp
@@ -84,18 +84,18 @@ detail::verifyBranchSuccessorOperands(Operation *op, unsigned succNo,
 // RegionBranchOpInterface
 //===----------------------------------------------------------------------===//
 
-static InFlightDiagnostic &printRegionEdgeName(InFlightDiagnostic &diag,
-                                               RegionBranchPoint sourceNo,
-                                               RegionBranchPoint succRegionNo) {
+static InFlightDiagnostic &
+printRegionEdgeName(InFlightDiagnostic &diag, std::optional<unsigned> sourceNo,
+                    std::optional<unsigned> succRegionNo) {
   diag << "from ";
-  if (Region *region = sourceNo.getRegionOrNull())
-    diag << "Region #" << region->getRegionNumber();
+  if (sourceNo)
+    diag << "Region #" << sourceNo.value();
   else
     diag << "parent operands";
 
   diag << " to ";
-  if (Region *region = succRegionNo.getRegionOrNull())
-    diag << "Region #" << region->getRegionNumber();
+  if (succRegionNo)
+    diag << "Region #" << succRegionNo.value();
   else
     diag << "parent results";
   return diag;
@@ -107,24 +107,28 @@ static InFlightDiagnostic &printRegionEdgeName(InFlightDiagnostic &diag,
 /// inputs that flow from `sourceIndex' to the given region, or std::nullopt if
 /// the exact type match verification is not necessary (e.g., if the Op verifies
 /// the match itself).
-static LogicalResult
-verifyTypesAlongAllEdges(Operation *op, RegionBranchPoint sourcePoint,
-                         function_ref<FailureOr<TypeRange>(RegionBranchPoint)>
-                             getInputsTypesForRegion) {
+static LogicalResult verifyTypesAlongAllEdges(
+    Operation *op, std::optional<unsigned> sourceNo,
+    function_ref<FailureOr<TypeRange>(std::optional<unsigned>)>
+        getInputsTypesForRegion) {
   auto regionInterface = cast<RegionBranchOpInterface>(op);
 
   SmallVector<RegionSuccessor, 2> successors;
-  regionInterface.getSuccessorRegions(sourcePoint, successors);
+  regionInterface.getSuccessorRegions(sourceNo, successors);
 
   for (RegionSuccessor &succ : successors) {
-    FailureOr<TypeRange> sourceTypes = getInputsTypesForRegion(succ);
+    std::optional<unsigned> succRegionNo;
+    if (!succ.isParent())
+      succRegionNo = succ.getSuccessor()->getRegionNumber();
+
+    FailureOr<TypeRange> sourceTypes = getInputsTypesForRegion(succRegionNo);
     if (failed(sourceTypes))
       return failure();
 
     TypeRange succInputsTypes = succ.getSuccessorInputs().getTypes();
     if (sourceTypes->size() != succInputsTypes.size()) {
       InFlightDiagnostic diag = op->emitOpError(" region control flow edge ");
-      return printRegionEdgeName(diag, sourcePoint, succ)
+      return printRegionEdgeName(diag, sourceNo, succRegionNo)
              << ": source has " << sourceTypes->size()
              << " operands, but target successor needs "
              << succInputsTypes.size();
@@ -136,7 +140,7 @@ verifyTypesAlongAllEdges(Operation *op, RegionBranchPoint sourcePoint,
       Type inputType = std::get<1>(typesIdx.value());
       if (!regionInterface.areTypesCompatible(sourceType, inputType)) {
         InFlightDiagnostic diag = op->emitOpError(" along control flow edge ");
-        return printRegionEdgeName(diag, sourcePoint, succ)
+        return printRegionEdgeName(diag, sourceNo, succRegionNo)
                << ": source type #" << typesIdx.index() << " " << sourceType
                << " should match input type #" << typesIdx.index() << " "
                << inputType;
@@ -150,13 +154,13 @@ verifyTypesAlongAllEdges(Operation *op, RegionBranchPoint sourcePoint,
 LogicalResult detail::verifyTypesAlongControlFlowEdges(Operation *op) {
   auto regionInterface = cast<RegionBranchOpInterface>(op);
 
-  auto inputTypesFromParent = [&](RegionBranchPoint regionNo) -> TypeRange {
+  auto inputTypesFromParent =
+      [&](std::optional<unsigned> regionNo) -> TypeRange {
     return regionInterface.getEntrySuccessorOperands(regionNo).getTypes();
   };
 
   // Verify types along control flow edges originating from the parent.
-  if (failed(verifyTypesAlongAllEdges(op, RegionBranchPoint::parent(),
-                                      inputTypesFromParent)))
+  if (failed(verifyTypesAlongAllEdges(op, std::nullopt, inputTypesFromParent)))
     return failure();
 
   auto areTypesCompatible = [&](TypeRange lhs, TypeRange rhs) {
@@ -172,7 +176,8 @@ LogicalResult detail::verifyTypesAlongControlFlowEdges(Operation *op) {
   };
 
   // Verify types along control flow edges originating from each region.
-  for (Region &region : op->getRegions()) {
+  for (unsigned regionNo : llvm::seq(0U, op->getNumRegions())) {
+    Region &region = op->getRegion(regionNo);
 
     // Since there can be multiple terminators implementing the
     // `RegionBranchTerminatorOpInterface`, all should have the same operand
@@ -190,7 +195,7 @@ LogicalResult detail::verifyTypesAlongControlFlowEdges(Operation *op) {
       continue;
 
     auto inputTypesForRegion =
-        [&](RegionBranchPoint succRegionNo) -> FailureOr<TypeRange> {
+        [&](std::optional<unsigned> succRegionNo) -> FailureOr<TypeRange> {
       std::optional<OperandRange> regionReturnOperands;
       for (RegionBranchTerminatorOpInterface regionReturnOp : regionReturnOps) {
         auto terminatorOperands =
@@ -206,7 +211,7 @@ LogicalResult detail::verifyTypesAlongControlFlowEdges(Operation *op) {
         if (!areTypesCompatible(regionReturnOperands->getTypes(),
                                 terminatorOperands.getTypes())) {
           InFlightDiagnostic diag = op->emitOpError(" along control flow edge");
-          return printRegionEdgeName(diag, region, succRegionNo)
+          return printRegionEdgeName(diag, regionNo, succRegionNo)
                  << " operands mismatch between return-like terminators";
         }
       }
@@ -215,7 +220,7 @@ LogicalResult detail::verifyTypesAlongControlFlowEdges(Operation *op) {
       return TypeRange(regionReturnOperands->getTypes());
     };
 
-    if (failed(verifyTypesAlongAllEdges(op, region, inputTypesForRegion)))
+    if (failed(verifyTypesAlongAllEdges(op, regionNo, inputTypesForRegion)))
       return failure();
   }
 
@@ -232,24 +237,24 @@ static bool isRegionReachable(Region *begin, Region *r) {
   visited[begin->getRegionNumber()] = true;
 
   // Retrieve all successors of the region and enqueue them in the worklist.
-  SmallVector<Region *> worklist;
-  auto enqueueAllSuccessors = [&](Region *region) {
+  SmallVector<unsigned> worklist;
+  auto enqueueAllSuccessors = [&](unsigned index) {
     SmallVector<RegionSuccessor> successors;
-    op.getSuccessorRegions(region, successors);
+    op.getSuccessorRegions(index, successors);
     for (RegionSuccessor successor : successors)
       if (!successor.isParent())
-        worklist.push_back(successor.getSuccessor());
+        worklist.push_back(successor.getSuccessor()->getRegionNumber());
   };
-  enqueueAllSuccessors(begin);
+  enqueueAllSuccessors(begin->getRegionNumber());
 
   // Process all regions in the worklist via DFS.
   while (!worklist.empty()) {
-    Region *nextRegion = worklist.pop_back_val();
-    if (nextRegion == r)
+    unsigned nextRegion = worklist.pop_back_val();
+    if (nextRegion == r->getRegionNumber())
       return true;
-    if (visited[nextRegion->getRegionNumber()])
+    if (visited[nextRegion])
       continue;
-    visited[nextRegion->getRegionNumber()] = true;
+    visited[nextRegion] = true;
     enqueueAllSuccessors(nextRegion);
   }
 

diff  --git a/mlir/lib/Transforms/RemoveDeadValues.cpp b/mlir/lib/Transforms/RemoveDeadValues.cpp
index 19a84db34dce02..ce19dc667f009d 100644
--- a/mlir/lib/Transforms/RemoveDeadValues.cpp
+++ b/mlir/lib/Transforms/RemoveDeadValues.cpp
@@ -316,11 +316,15 @@ static void cleanRegionBranchOp(RegionBranchOpInterface regionBranchOp,
   // Return the successors of `region` if the latter is not null. Else return
   // the successors of `regionBranchOp`.
   auto getSuccessors = [&](Region *region = nullptr) {
-    auto point = region ? region : RegionBranchPoint::parent();
+    std::optional<unsigned> index =
+        region ? std::optional(region->getRegionNumber()) : std::nullopt;
     SmallVector<Attribute> operandAttributes(regionBranchOp->getNumOperands(),
                                              nullptr);
     SmallVector<RegionSuccessor> successors;
-    regionBranchOp.getSuccessorRegions(point, successors);
+    if (!index)
+      regionBranchOp.getEntrySuccessorRegions(operandAttributes, successors);
+    else
+      regionBranchOp.getSuccessorRegions(index, successors);
     return successors;
   };
 
@@ -329,10 +333,14 @@ static void cleanRegionBranchOp(RegionBranchOpInterface regionBranchOp,
   // forwarded to `successor`.
   auto getForwardedOpOperands = [&](const RegionSuccessor &successor,
                                     Operation *terminator = nullptr) {
+    Region *successorRegion = successor.getSuccessor();
+    std::optional<unsigned> index =
+        successorRegion ? std::optional(successorRegion->getRegionNumber())
+                        : std::nullopt;
     OperandRange operands =
         terminator ? cast<RegionBranchTerminatorOpInterface>(terminator)
-                         .getSuccessorOperands(successor)
-                   : regionBranchOp.getEntrySuccessorOperands(successor);
+                         .getSuccessorOperands(index)
+                   : regionBranchOp.getEntrySuccessorOperands(index);
     SmallVector<OpOperand *> opOperands = operandsToOpOperands(operands);
     return opOperands;
   };

diff  --git a/mlir/test/lib/Analysis/DataFlow/TestDenseBackwardDataFlowAnalysis.cpp b/mlir/test/lib/Analysis/DataFlow/TestDenseBackwardDataFlowAnalysis.cpp
index 8bfd01d828060a..a33b523d5d192f 100644
--- a/mlir/test/lib/Analysis/DataFlow/TestDenseBackwardDataFlowAnalysis.cpp
+++ b/mlir/test/lib/Analysis/DataFlow/TestDenseBackwardDataFlowAnalysis.cpp
@@ -60,8 +60,8 @@ class NextAccessAnalysis : public DenseBackwardDataFlowAnalysis<NextAccess> {
                                     NextAccess *before) override;
 
   void visitRegionBranchControlFlowTransfer(RegionBranchOpInterface branch,
-                                            RegionBranchPoint regionFrom,
-                                            RegionBranchPoint regionTo,
+                                            std::optional<unsigned> regionFrom,
+                                            std::optional<unsigned> regionTo,
                                             const NextAccess &after,
                                             NextAccess *before) override;
 
@@ -124,15 +124,15 @@ void NextAccessAnalysis::visitCallControlFlowTransfer(
 }
 
 void NextAccessAnalysis::visitRegionBranchControlFlowTransfer(
-    RegionBranchOpInterface branch, RegionBranchPoint regionFrom,
-    RegionBranchPoint regionTo, const NextAccess &after, NextAccess *before) {
+    RegionBranchOpInterface branch, std::optional<unsigned> regionFrom,
+    std::optional<unsigned> regionTo, const NextAccess &after,
+    NextAccess *before) {
   auto testStoreWithARegion =
       dyn_cast<::test::TestStoreWithARegion>(branch.getOperation());
 
   if (testStoreWithARegion &&
-      ((regionTo.isParent() && !testStoreWithARegion.getStoreBeforeRegion()) ||
-       (regionFrom.isParent() &&
-        testStoreWithARegion.getStoreBeforeRegion()))) {
+      ((!regionTo && !testStoreWithARegion.getStoreBeforeRegion()) ||
+       (!regionFrom && testStoreWithARegion.getStoreBeforeRegion()))) {
     visitOperation(branch, static_cast<const NextAccess &>(after),
                    static_cast<NextAccess *>(before));
   } else {
@@ -219,7 +219,7 @@ struct TestNextAccessPass
 
       SmallVector<Attribute> entryPointNextAccess;
       SmallVector<RegionSuccessor> regionSuccessors;
-      iface.getSuccessorRegions(RegionBranchPoint::parent(), regionSuccessors);
+      iface.getSuccessorRegions(std::nullopt, regionSuccessors);
       for (const RegionSuccessor &successor : regionSuccessors) {
         if (!successor.getSuccessor() || successor.getSuccessor()->empty())
           continue;

diff  --git a/mlir/test/lib/Dialect/Test/TestDialect.cpp b/mlir/test/lib/Dialect/Test/TestDialect.cpp
index 34ed7a1a66fe33..57a6ab387281dc 100644
--- a/mlir/test/lib/Dialect/Test/TestDialect.cpp
+++ b/mlir/test/lib/Dialect/Test/TestDialect.cpp
@@ -931,17 +931,17 @@ ParseResult RegionIfOp::parse(OpAsmParser &parser, OperationState &result) {
                                 parser.getCurrentLocation(), result.operands);
 }
 
-OperandRange RegionIfOp::getEntrySuccessorOperands(RegionBranchPoint point) {
-  assert(llvm::is_contained({&getThenRegion(), &getElseRegion()}, point) &&
-         "invalid region index");
+OperandRange
+RegionIfOp::getEntrySuccessorOperands(std::optional<unsigned> index) {
+  assert(index && *index < 2 && "invalid region index");
   return getOperands();
 }
 
 void RegionIfOp::getSuccessorRegions(
-    RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
+    std::optional<unsigned> index, SmallVectorImpl<RegionSuccessor> &regions) {
   // We always branch to the join region.
-  if (!point.isParent()) {
-    if (point != getJoinRegion())
+  if (index.has_value()) {
+    if (index.value() < 2)
       regions.push_back(RegionSuccessor(&getJoinRegion(), getJoinArgs()));
     else
       regions.push_back(RegionSuccessor(getResults()));
@@ -964,11 +964,11 @@ void RegionIfOp::getRegionInvocationBounds(
 // AnyCondOp
 //===----------------------------------------------------------------------===//
 
-void AnyCondOp::getSuccessorRegions(RegionBranchPoint point,
+void AnyCondOp::getSuccessorRegions(std::optional<unsigned> index,
                                     SmallVectorImpl<RegionSuccessor> &regions) {
   // The parent op branches into the only region, and the region branches back
   // to the parent op.
-  if (point.isParent())
+  if (!index)
     regions.emplace_back(&getRegion());
   else
     regions.emplace_back(getResults());
@@ -985,16 +985,17 @@ void AnyCondOp::getRegionInvocationBounds(
 //===----------------------------------------------------------------------===//
 
 void LoopBlockOp::getSuccessorRegions(
-    RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
+    std::optional<unsigned> index, SmallVectorImpl<RegionSuccessor> &regions) {
   regions.emplace_back(&getBody(), getBody().getArguments());
-  if (point.isParent())
+  if (!index)
     return;
 
   regions.emplace_back((*this)->getResults());
 }
 
-OperandRange LoopBlockOp::getEntrySuccessorOperands(RegionBranchPoint point) {
-  assert(point == getBody());
+OperandRange
+LoopBlockOp::getEntrySuccessorOperands(std::optional<unsigned> index) {
+  assert(index == 0);
   return getInitMutable();
 }
 
@@ -1002,9 +1003,10 @@ OperandRange LoopBlockOp::getEntrySuccessorOperands(RegionBranchPoint point) {
 // LoopBlockTerminatorOp
 //===----------------------------------------------------------------------===//
 
-MutableOperandRange
-LoopBlockTerminatorOp::getMutableSuccessorOperands(RegionBranchPoint point) {
-  if (point.isParent())
+MutableOperandRange LoopBlockTerminatorOp::getMutableSuccessorOperands(
+    std::optional<unsigned> index) {
+  assert(!index || index == 0);
+  if (!index)
     return getExitArgMutable();
   return getNextIterArgMutable();
 }
@@ -1311,11 +1313,12 @@ MutableOperandRange TestCallOnDeviceOp::getArgOperandsMutable() {
 }
 
 void TestStoreWithARegion::getSuccessorRegions(
-    RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
-  if (point.isParent())
+    std::optional<unsigned> index, SmallVectorImpl<RegionSuccessor> &regions) {
+  if (!index) {
     regions.emplace_back(&getBody(), getBody().front().getArguments());
-  else
+  } else {
     regions.emplace_back();
+  }
 }
 
 LogicalResult

diff  --git a/mlir/unittests/Interfaces/ControlFlowInterfacesTest.cpp b/mlir/unittests/Interfaces/ControlFlowInterfacesTest.cpp
index f1aae15393fd3f..a507baa6445d97 100644
--- a/mlir/unittests/Interfaces/ControlFlowInterfacesTest.cpp
+++ b/mlir/unittests/Interfaces/ControlFlowInterfacesTest.cpp
@@ -37,7 +37,7 @@ struct MutuallyExclusiveRegionsOp
   }
 
   // Regions have no successors.
-  void getSuccessorRegions(RegionBranchPoint point,
+  void getSuccessorRegions(std::optional<unsigned> index,
                            SmallVectorImpl<RegionSuccessor> &regions) {}
 };
 
@@ -51,13 +51,14 @@ struct LoopRegionsOp
 
   static StringRef getOperationName() { return "cftest.loop_regions_op"; }
 
-  void getSuccessorRegions(RegionBranchPoint point,
+  void getSuccessorRegions(std::optional<unsigned> index,
                            SmallVectorImpl<RegionSuccessor> &regions) {
-    if (Region *region = point.getRegionOrNull()) {
-      if (point == (*this)->getRegion(1))
+    if (index) {
+      if (*index == 1)
         // This region also branches back to the parent.
         regions.push_back(RegionSuccessor());
-      regions.push_back(RegionSuccessor(region));
+      regions.push_back(
+          RegionSuccessor(&getOperation()->getRegion(*index % kNumRegions)));
     }
   }
 };
@@ -73,11 +74,11 @@ struct DoubleLoopRegionsOp
     return "cftest.double_loop_regions_op";
   }
 
-  void getSuccessorRegions(RegionBranchPoint point,
+  void getSuccessorRegions(std::optional<unsigned> index,
                            SmallVectorImpl<RegionSuccessor> &regions) {
-    if (Region *region = point.getRegionOrNull()) {
+    if (index.has_value()) {
       regions.push_back(RegionSuccessor());
-      regions.push_back(RegionSuccessor(region));
+      regions.push_back(RegionSuccessor(&getOperation()->getRegion(*index)));
     }
   }
 };
@@ -91,9 +92,9 @@ struct SequentialRegionsOp
   static StringRef getOperationName() { return "cftest.sequential_regions_op"; }
 
   // Region 0 has Region 1 as a successor.
-  void getSuccessorRegions(RegionBranchPoint point,
+  void getSuccessorRegions(std::optional<unsigned> index,
                            SmallVectorImpl<RegionSuccessor> &regions) {
-    if (point == (*this)->getRegion(0)) {
+    if (index == 0u) {
       Operation *thisOp = this->getOperation();
       regions.push_back(RegionSuccessor(&thisOp->getRegion(1)));
     }


        


More information about the Mlir-commits mailing list