[Mlir-commits] [mlir] 4dd744a - Reland "[mlir] Use a type for representing branch points in `RegionBranchOpInterface`"

Markus Böck llvmlistbot at llvm.org
Wed Aug 30 00:32:01 PDT 2023


Author: Markus Böck
Date: 2023-08-30T09:31:54+02:00
New Revision: 4dd744ac9c0f772a61dd91c84bc14d17e69aec51

URL: https://github.com/llvm/llvm-project/commit/4dd744ac9c0f772a61dd91c84bc14d17e69aec51
DIFF: https://github.com/llvm/llvm-project/commit/4dd744ac9c0f772a61dd91c84bc14d17e69aec51.diff

LOG: Reland "[mlir] Use a type for representing branch points in `RegionBranchOpInterface`"

This reverts commit b26bb30b467b996c9786e3bd426c07684d84d406.

Added: 
    

Modified: 
    flang/lib/Optimizer/Dialect/FIROps.cpp
    mlir/include/mlir/Analysis/DataFlow/DenseAnalysis.h
    mlir/include/mlir/Analysis/DataFlow/SparseAnalysis.h
    mlir/include/mlir/Interfaces/ControlFlowInterfaces.h
    mlir/include/mlir/Interfaces/ControlFlowInterfaces.td
    mlir/lib/Analysis/AliasAnalysis/LocalAliasAnalysis.cpp
    mlir/lib/Analysis/DataFlow/DenseAnalysis.cpp
    mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp
    mlir/lib/Dialect/Affine/IR/AffineOps.cpp
    mlir/lib/Dialect/Async/IR/Async.cpp
    mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocation.cpp
    mlir/lib/Dialect/Bufferization/Transforms/BufferUtils.cpp
    mlir/lib/Dialect/Bufferization/Transforms/BufferViewFlowAnalysis.cpp
    mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
    mlir/lib/Dialect/SCF/IR/SCF.cpp
    mlir/lib/Dialect/Shape/IR/Shape.cpp
    mlir/lib/Dialect/Transform/IR/TransformOps.cpp
    mlir/lib/Dialect/Vector/IR/VectorOps.cpp
    mlir/lib/Interfaces/ControlFlowInterfaces.cpp
    mlir/lib/Transforms/RemoveDeadValues.cpp
    mlir/test/lib/Analysis/DataFlow/TestDenseBackwardDataFlowAnalysis.cpp
    mlir/test/lib/Dialect/Test/TestDialect.cpp
    mlir/unittests/Interfaces/ControlFlowInterfacesTest.cpp

Removed: 
    


################################################################################
diff  --git a/flang/lib/Optimizer/Dialect/FIROps.cpp b/flang/lib/Optimizer/Dialect/FIROps.cpp
index bbe06577c27e7b..80567b19f9fe5e 100644
--- a/flang/lib/Optimizer/Dialect/FIROps.cpp
+++ b/flang/lib/Optimizer/Dialect/FIROps.cpp
@@ -3467,10 +3467,10 @@ void fir::IfOp::build(mlir::OpBuilder &builder, mlir::OperationState &result,
 /// return the successor regions. These are the regions that may be selected
 /// during the flow of control.
 void fir::IfOp::getSuccessorRegions(
-    std::optional<unsigned> index,
+    mlir::RegionBranchPoint point,
     llvm::SmallVectorImpl<mlir::RegionSuccessor> &regions) {
   // The `then` and the `else` region branch back to the parent operation.
-  if (index) {
+  if (!point.isParent()) {
     regions.push_back(mlir::RegionSuccessor(getResults()));
     return;
   }

diff  --git a/mlir/include/mlir/Analysis/DataFlow/DenseAnalysis.h b/mlir/include/mlir/Analysis/DataFlow/DenseAnalysis.h
index a3a558f7705074..6a1335bab8bf6e 100644
--- a/mlir/include/mlir/Analysis/DataFlow/DenseAnalysis.h
+++ b/mlir/include/mlir/Analysis/DataFlow/DenseAnalysis.h
@@ -353,8 +353,8 @@ class AbstractDenseBackwardDataFlowAnalysis : public DataFlowAnalysis {
   /// any effect on the lattice that isn't already expressed by the interface
   /// itself.
   virtual void visitRegionBranchControlFlowTransfer(
-      RegionBranchOpInterface branch, std::optional<unsigned> regionFrom,
-      std::optional<unsigned> regionTo, const AbstractDenseLattice &after,
+      RegionBranchOpInterface branch, RegionBranchPoint regionFrom,
+      RegionBranchPoint regionTo, const AbstractDenseLattice &after,
       AbstractDenseLattice *before) {
     meet(before, after);
   }
@@ -382,7 +382,7 @@ class AbstractDenseBackwardDataFlowAnalysis : public DataFlowAnalysis {
   /// of the branch operation itself.
   void visitRegionBranchOperation(ProgramPoint point,
                                   RegionBranchOpInterface branch,
-                                  std::optional<unsigned> regionNo,
+                                  RegionBranchPoint branchPoint,
                                   AbstractDenseLattice *before);
 
   /// Visit an operation for which the data flow is described by the
@@ -472,9 +472,8 @@ class DenseBackwardDataFlowAnalysis
   /// nullptr`. The behavior can be further refined for specific pairs of "from"
   /// and "to" regions.
   virtual void visitRegionBranchControlFlowTransfer(
-      RegionBranchOpInterface branch, std::optional<unsigned> regionFrom,
-      std::optional<unsigned> regionTo, const LatticeT &after,
-      LatticeT *before) {
+      RegionBranchOpInterface branch, RegionBranchPoint regionFrom,
+      RegionBranchPoint regionTo, const LatticeT &after, LatticeT *before) {
     AbstractDenseBackwardDataFlowAnalysis::visitRegionBranchControlFlowTransfer(
         branch, regionFrom, regionTo, after, before);
   }
@@ -508,8 +507,8 @@ class DenseBackwardDataFlowAnalysis
                                  static_cast<LatticeT *>(before));
   }
   void visitRegionBranchControlFlowTransfer(
-      RegionBranchOpInterface branch, std::optional<unsigned> regionForm,
-      std::optional<unsigned> regionTo, const AbstractDenseLattice &after,
+      RegionBranchOpInterface branch, RegionBranchPoint regionForm,
+      RegionBranchPoint regionTo, const AbstractDenseLattice &after,
       AbstractDenseLattice *before) final {
     visitRegionBranchControlFlowTransfer(branch, regionForm, regionTo,
                                          static_cast<const LatticeT &>(after),

diff  --git a/mlir/include/mlir/Analysis/DataFlow/SparseAnalysis.h b/mlir/include/mlir/Analysis/DataFlow/SparseAnalysis.h
index 13dacff3aa0422..5a9a36159b56c5 100644
--- a/mlir/include/mlir/Analysis/DataFlow/SparseAnalysis.h
+++ b/mlir/include/mlir/Analysis/DataFlow/SparseAnalysis.h
@@ -243,7 +243,7 @@ class AbstractSparseForwardDataFlowAnalysis : public DataFlowAnalysis {
   /// regions or the parent operation itself, and set either the argument or
   /// parent result lattices.
   void visitRegionSuccessors(ProgramPoint point, RegionBranchOpInterface branch,
-                             std::optional<unsigned> successorIndex,
+                             RegionBranchPoint successor,
                              ArrayRef<AbstractSparseLattice *> lattices);
 };
 

diff  --git a/mlir/include/mlir/Interfaces/ControlFlowInterfaces.h b/mlir/include/mlir/Interfaces/ControlFlowInterfaces.h
index bd81da41ed43cd..006aedced839f9 100644
--- a/mlir/include/mlir/Interfaces/ControlFlowInterfaces.h
+++ b/mlir/include/mlir/Interfaces/ControlFlowInterfaces.h
@@ -190,6 +190,68 @@ class RegionSuccessor {
   ValueRange inputs;
 };
 
+/// This class represents a point being branched from in the methods of the
+/// `RegionBranchOpInterface`.
+/// One can branch from one of two kinds of places:
+/// * The parent operation (aka the `RegionBranchOpInterface` implementation)
+/// * A region within the parent operation.
+class RegionBranchPoint {
+public:
+  /// Returns an instance of `RegionBranchPoint` representing the parent
+  /// operation.
+  static constexpr RegionBranchPoint parent() { return RegionBranchPoint(); }
+
+  /// Creates a `RegionBranchPoint` that branches from the given region.
+  /// The pointer must not be null.
+  RegionBranchPoint(Region *region) : maybeRegion(region) {
+    assert(region && "Region must not be null");
+  }
+
+  RegionBranchPoint(Region &region) : RegionBranchPoint(&region) {}
+
+  /// Explicitly stops users from constructing with `nullptr`.
+  RegionBranchPoint(std::nullptr_t) = delete;
+
+  /// Constructs a `RegionBranchPoint` from the target of a
+  /// `RegionSuccessor` instance.
+  RegionBranchPoint(RegionSuccessor successor) {
+    if (successor.isParent())
+      maybeRegion = nullptr;
+    else
+      maybeRegion = successor.getSuccessor();
+  }
+
+  /// Assigns a region being branched from.
+  RegionBranchPoint &operator=(Region &region) {
+    maybeRegion = &region;
+    return *this;
+  }
+
+  /// Returns true if branching from the parent op.
+  bool isParent() const { return maybeRegion == nullptr; }
+
+  /// Returns the region if branching from a region.
+  /// A null pointer otherwise.
+  Region *getRegionOrNull() const { return maybeRegion; }
+
+  /// Returns true if the two branch points are equal.
+  friend bool operator==(RegionBranchPoint lhs, RegionBranchPoint rhs) {
+    return lhs.maybeRegion == rhs.maybeRegion;
+  }
+
+private:
+  // Private constructor to encourage the use of `RegionBranchPoint::parent`.
+  constexpr RegionBranchPoint() : maybeRegion(nullptr) {}
+
+  /// Internal encoding. Uses nullptr for representing branching from the parent
+  /// op and the region being branched from otherwise.
+  Region *maybeRegion;
+};
+
+inline bool operator!=(RegionBranchPoint lhs, RegionBranchPoint rhs) {
+  return !(lhs == rhs);
+}
+
 /// This class represents upper and lower bounds on the number of times a region
 /// of a `RegionBranchOpInterface` can be invoked. The lower bound is at least
 /// zero, but the upper bound may not be known.

diff  --git a/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td b/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td
index 132bd6d53d923a..e52636a5ac8fcc 100644
--- a/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td
+++ b/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td
@@ -133,14 +133,14 @@ def RegionBranchOpInterface : OpInterface<"RegionBranchOpInterface"> {
   let methods = [
     InterfaceMethod<[{
         Returns the operands of this operation used as the entry arguments when
-        entering the region at `index`, which was specified as a successor of
+        branching from `point`, which was specified as a successor of
         this operation by `getEntrySuccessorRegions`, or the operands forwarded
         to the operation's results when it branches back to itself. These operands
         should correspond 1-1 with the successor inputs specified in
         `getEntrySuccessorRegions`.
       }],
       "::mlir::OperandRange", "getEntrySuccessorOperands",
-      (ins "::std::optional<unsigned>":$index), [{}],
+      (ins "::mlir::RegionBranchPoint":$point), [{}],
       /*defaultImplementation=*/[{
         auto operandEnd = this->getOperation()->operand_end();
         return ::mlir::OperandRange(operandEnd, operandEnd);
@@ -162,22 +162,20 @@ def RegionBranchOpInterface : OpInterface<"RegionBranchOpInterface"> {
       (ins "::llvm::ArrayRef<::mlir::Attribute>":$operands,
            "::llvm::SmallVectorImpl<::mlir::RegionSuccessor> &":$regions),
            [{}], [{
-        $_op.getSuccessorRegions(std::nullopt, regions);
+        $_op.getSuccessorRegions(mlir::RegionBranchPoint::parent(), regions);
       }]
     >,
     InterfaceMethod<[{
-        Returns the viable successors of a region at `index`, or the possible
-        successors when branching from the parent op if `index` is None. These
-        are the regions that may be selected during the flow of control. The
-        parent operation, i.e. a null `index`, may specify itself as successor,
-        which indicates that the control flow may not enter any region at all.
-        This method allows for describing which regions may be executed when
-        entering an operation, and which regions are executed after having
-        executed another region of the parent op. The successor region must be
-        non-empty.
+        Returns the viable successors of `point`. These are the regions that may
+        be selected during the flow of control. The parent operation may
+        specify itself as successor, which indicates that the control flow may
+        not enter any region at all. This method allows for describing which
+        regions may be executed when entering an operation, and which regions
+        are executed after having executed another region of the parent op. The
+        successor region must be non-empty.
       }],
       "void", "getSuccessorRegions",
-      (ins "::std::optional<unsigned>":$index,
+      (ins "::mlir::RegionBranchPoint":$point,
            "::llvm::SmallVectorImpl<::mlir::RegionSuccessor> &":$regions)
     >,
     InterfaceMethod<[{
@@ -245,12 +243,10 @@ def RegionBranchTerminatorOpInterface :
   let methods = [
     InterfaceMethod<[{
         Returns a mutable range of operands that are semantically "returned" by
-        passing them to the region successor given by `index`.  If `index` is
-        None, this function returns the operands that are passed as a result to
-        the parent operation.
+        passing them to the region successor given by `point`.
       }],
       "::mlir::MutableOperandRange", "getMutableSuccessorOperands",
-      (ins "::std::optional<unsigned>":$index)
+      (ins "::mlir::RegionBranchPoint":$point)
     >,
     InterfaceMethod<[{
         Returns the viable region successors that are branched to after this
@@ -269,8 +265,7 @@ def RegionBranchTerminatorOpInterface :
       [{
         ::mlir::Operation *op = $_op;
         ::llvm::cast<::mlir::RegionBranchOpInterface>(op->getParentOp())
-          .getSuccessorRegions(op->getParentRegion()->getRegionNumber(),
-            regions);
+          .getSuccessorRegions(op->getParentRegion(), regions);
       }]
     >,
   ];
@@ -290,8 +285,8 @@ def RegionBranchTerminatorOpInterface :
     // them to the region successor given by `index`.  If `index` is None, this
     // function returns the operands that are passed as a result to the parent
     // operation.
-    ::mlir::OperandRange getSuccessorOperands(std::optional<unsigned> index) {
-      return getMutableSuccessorOperands(index);
+    ::mlir::OperandRange getSuccessorOperands(::mlir::RegionBranchPoint point) {
+      return getMutableSuccessorOperands(point);
     }
   }];
 }
@@ -309,7 +304,7 @@ def ReturnLike : TraitList<[
         /*extraOpDeclaration=*/"",
         /*extraOpDefinition=*/[{
           ::mlir::MutableOperandRange $cppClass::getMutableSuccessorOperands(
-            ::std::optional<unsigned> index) {
+            ::mlir::RegionBranchPoint point) {
             return ::mlir::MutableOperandRange(*this);
           }
         }]

diff  --git a/mlir/lib/Analysis/AliasAnalysis/LocalAliasAnalysis.cpp b/mlir/lib/Analysis/AliasAnalysis/LocalAliasAnalysis.cpp
index 970e68bc258649..ae2ba90412137c 100644
--- a/mlir/lib/Analysis/AliasAnalysis/LocalAliasAnalysis.cpp
+++ b/mlir/lib/Analysis/AliasAnalysis/LocalAliasAnalysis.cpp
@@ -45,9 +45,9 @@ static void collectUnderlyingAddressValues(RegionBranchOpInterface branch,
   // this region predecessor that correspond to the input values of `region`. If
   // an index could not be found, std::nullopt is returned instead.
   auto getOperandIndexIfPred =
-      [&](std::optional<unsigned> predIndex) -> std::optional<unsigned> {
+      [&](RegionBranchPoint pred) -> std::optional<unsigned> {
     SmallVector<RegionSuccessor, 2> successors;
-    branch.getSuccessorRegions(predIndex, successors);
+    branch.getSuccessorRegions(pred, successors);
     for (RegionSuccessor &successor : successors) {
       if (successor.getSuccessor() != region)
         continue;
@@ -75,28 +75,27 @@ static void collectUnderlyingAddressValues(RegionBranchOpInterface branch,
   };
 
   // Check branches from the parent operation.
-  std::optional<unsigned> regionIndex;
-  if (region) {
-    // Determine the actual region number from the passed region.
-    regionIndex = region->getRegionNumber();
-  }
+  auto branchPoint = RegionBranchPoint::parent();
+  if (region)
+    branchPoint = region;
+
   if (std::optional<unsigned> operandIndex =
-          getOperandIndexIfPred(/*predIndex=*/std::nullopt)) {
+          getOperandIndexIfPred(/*predIndex=*/RegionBranchPoint::parent())) {
     collectUnderlyingAddressValues(
-        branch.getEntrySuccessorOperands(regionIndex)[*operandIndex], maxDepth,
+        branch.getEntrySuccessorOperands(branchPoint)[*operandIndex], maxDepth,
         visited, output);
   }
   // Check branches from each child region.
   Operation *op = branch.getOperation();
-  for (int i = 0, e = op->getNumRegions(); i != e; ++i) {
-    if (std::optional<unsigned> operandIndex = getOperandIndexIfPred(i)) {
-      for (Block &block : op->getRegion(i)) {
+  for (Region &region : op->getRegions()) {
+    if (std::optional<unsigned> operandIndex = getOperandIndexIfPred(region)) {
+      for (Block &block : region) {
         // Try to determine possible region-branch successor operands for the
         // current region.
         if (auto term = dyn_cast<RegionBranchTerminatorOpInterface>(
                 block.getTerminator())) {
           collectUnderlyingAddressValues(
-              term.getSuccessorOperands(regionIndex)[*operandIndex], maxDepth,
+              term.getSuccessorOperands(branchPoint)[*operandIndex], maxDepth,
               visited, output);
         } else if (block.getNumSuccessors()) {
           // Otherwise, if this terminator may exit the region we can't make

diff  --git a/mlir/lib/Analysis/DataFlow/DenseAnalysis.cpp b/mlir/lib/Analysis/DataFlow/DenseAnalysis.cpp
index c79a360e4c11bb..eab408cd5977c3 100644
--- a/mlir/lib/Analysis/DataFlow/DenseAnalysis.cpp
+++ b/mlir/lib/Analysis/DataFlow/DenseAnalysis.cpp
@@ -312,7 +312,8 @@ void AbstractDenseBackwardDataFlowAnalysis::processOperation(Operation *op) {
 
   // Special cases where control flow may dictate data flow.
   if (auto branch = dyn_cast<RegionBranchOpInterface>(op))
-    return visitRegionBranchOperation(op, branch, std::nullopt, before);
+    return visitRegionBranchOperation(op, branch, RegionBranchPoint::parent(),
+                                      before);
   if (auto call = dyn_cast<CallOpInterface>(op))
     return visitCallOperation(call, before);
 
@@ -368,8 +369,7 @@ void AbstractDenseBackwardDataFlowAnalysis::visitBlock(Block *block) {
     // If this block is exiting from an operation with region-based control
     // flow, propagate the lattice back along the control flow edge.
     if (auto branch = dyn_cast<RegionBranchOpInterface>(block->getParentOp())) {
-      visitRegionBranchOperation(block, branch,
-                                 block->getParent()->getRegionNumber(), before);
+      visitRegionBranchOperation(block, branch, block->getParent(), before);
       return;
     }
 
@@ -396,13 +396,13 @@ void AbstractDenseBackwardDataFlowAnalysis::visitBlock(Block *block) {
 
 void AbstractDenseBackwardDataFlowAnalysis::visitRegionBranchOperation(
     ProgramPoint point, RegionBranchOpInterface branch,
-    std::optional<unsigned> regionNo, AbstractDenseLattice *before) {
+    RegionBranchPoint branchPoint, AbstractDenseLattice *before) {
 
   // The successors of the operation may be either the first operation of the
   // entry block of each possible successor region, or the next operation when
   // the branch is a successor of itself.
   SmallVector<RegionSuccessor> successors;
-  branch.getSuccessorRegions(regionNo, successors);
+  branch.getSuccessorRegions(branchPoint, successors);
   for (const RegionSuccessor &successor : successors) {
     const AbstractDenseLattice *after;
     if (successor.isParent() || successor.getSuccessor()->empty()) {
@@ -423,10 +423,8 @@ void AbstractDenseBackwardDataFlowAnalysis::visitRegionBranchOperation(
       else
         after = getLatticeFor(point, &successorBlock->front());
     }
-    std::optional<unsigned> successorNo =
-        successor.isParent() ? std::optional<unsigned>()
-                             : successor.getSuccessor()->getRegionNumber();
-    visitRegionBranchControlFlowTransfer(branch, regionNo, successorNo, *after,
+
+    visitRegionBranchControlFlowTransfer(branch, branchPoint, successor, *after,
                                          before);
   }
 }

diff  --git a/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp b/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp
index 4708cdb042f126..02a0ce1bb29213 100644
--- a/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp
+++ b/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp
@@ -99,7 +99,7 @@ void AbstractSparseForwardDataFlowAnalysis::visitOperation(Operation *op) {
   // The results of a region branch operation are determined by control-flow.
   if (auto branch = dyn_cast<RegionBranchOpInterface>(op)) {
     return visitRegionSuccessors({branch}, branch,
-                                 /*successorIndex=*/std::nullopt,
+                                 /*successor=*/RegionBranchPoint::parent(),
                                  resultLattices);
   }
 
@@ -167,8 +167,8 @@ void AbstractSparseForwardDataFlowAnalysis::visitBlock(Block *block) {
 
     // Check if the lattices can be determined from region control flow.
     if (auto branch = dyn_cast<RegionBranchOpInterface>(block->getParentOp())) {
-      return visitRegionSuccessors(
-          block, branch, block->getParent()->getRegionNumber(), argLattices);
+      return visitRegionSuccessors(block, branch, block->getParent(),
+                                   argLattices);
     }
 
     // Otherwise, we can't reason about the data-flow.
@@ -212,8 +212,7 @@ void AbstractSparseForwardDataFlowAnalysis::visitBlock(Block *block) {
 
 void AbstractSparseForwardDataFlowAnalysis::visitRegionSuccessors(
     ProgramPoint point, RegionBranchOpInterface branch,
-    std::optional<unsigned> successorIndex,
-    ArrayRef<AbstractSparseLattice *> lattices) {
+    RegionBranchPoint successor, ArrayRef<AbstractSparseLattice *> lattices) {
   const auto *predecessors = getOrCreateFor<PredecessorState>(point, point);
   assert(predecessors->allPredecessorsKnown() &&
          "unexpected unresolved region successors");
@@ -224,11 +223,11 @@ void AbstractSparseForwardDataFlowAnalysis::visitRegionSuccessors(
 
     // Check if the predecessor is the parent op.
     if (op == branch) {
-      operands = branch.getEntrySuccessorOperands(successorIndex);
+      operands = branch.getEntrySuccessorOperands(successor);
       // Otherwise, try to deduce the operands from a region return-like op.
     } else if (auto regionTerminator =
                    dyn_cast<RegionBranchTerminatorOpInterface>(op)) {
-      operands = regionTerminator.getSuccessorOperands(successorIndex);
+      operands = regionTerminator.getSuccessorOperands(successor);
     }
 
     if (!operands) {
@@ -501,10 +500,7 @@ void AbstractSparseBackwardDataFlowAnalysis::visitRegionSuccessors(
   BitVector unaccounted(op->getNumOperands(), true);
 
   for (RegionSuccessor &successor : successors) {
-    Region *region = successor.getSuccessor();
-    OperandRange operands =
-        region ? branch.getEntrySuccessorOperands(region->getRegionNumber())
-               : branch.getEntrySuccessorOperands({});
+    OperandRange operands = branch.getEntrySuccessorOperands(successor);
     MutableArrayRef<OpOperand> opoperands = operandsToOpOperands(operands);
     ValueRange inputs = successor.getSuccessorInputs();
     for (auto [operand, input] : llvm::zip(opoperands, inputs)) {
@@ -538,9 +534,7 @@ void AbstractSparseBackwardDataFlowAnalysis::
 
   for (const RegionSuccessor &successor : successors) {
     ValueRange inputs = successor.getSuccessorInputs();
-    Region *region = successor.getSuccessor();
-    OperandRange operands = terminator.getSuccessorOperands(
-        region ? region->getRegionNumber() : std::optional<unsigned>{});
+    OperandRange operands = terminator.getSuccessorOperands(successor);
     MutableArrayRef<OpOperand> opOperands = operandsToOpOperands(operands);
     for (auto [opOperand, input] : llvm::zip(opOperands, inputs)) {
       meet(getLatticeElement(opOperand.get()),

diff  --git a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
index bb4aaee21d019e..9d7b8f371a26c6 100644
--- a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
+++ b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
@@ -2379,9 +2379,9 @@ void AffineForOp::getCanonicalizationPatterns(RewritePatternSet &results,
 /// correspond to the loop iterator operands, i.e., those excluding the
 /// induction variable. AffineForOp only has one region, so zero is the only
 /// valid value for `index`.
-OperandRange
-AffineForOp::getEntrySuccessorOperands(std::optional<unsigned> index) {
-  assert((!index || *index == 0) && "invalid region index");
+OperandRange AffineForOp::getEntrySuccessorOperands(RegionBranchPoint point) {
+  assert((point.isParent() || point == getLoopBody()) &&
+         "invalid region point");
 
   // The initial operands map to the loop arguments after the induction
   // variable or are forwarded to the results when the trip count is zero.
@@ -2394,14 +2394,15 @@ AffineForOp::getEntrySuccessorOperands(std::optional<unsigned> index) {
 /// correspond to a constant value for each operand, or null if that operand is
 /// not a constant.
 void AffineForOp::getSuccessorRegions(
-    std::optional<unsigned> index, SmallVectorImpl<RegionSuccessor> &regions) {
-  assert((!index.has_value() || index.value() == 0) && "expected loop region");
+    RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
+  assert((point.isParent() || point == getLoopBody()) &&
+         "expected loop region");
   // The loop may typically branch back to its body or to the parent operation.
   // If the predecessor is the parent op and the trip count is known to be at
   // least one, branch into the body using the iterator arguments. And in cases
   // we know the trip count is zero, it can only branch back to its parent.
   std::optional<uint64_t> tripCount = getTrivialConstantTripCount(*this);
-  if (!index.has_value() && tripCount.has_value()) {
+  if (point.isParent() && tripCount.has_value()) {
     if (tripCount.value() > 0) {
       regions.push_back(RegionSuccessor(&getLoopBody(), getRegionIterArgs()));
       return;
@@ -2414,7 +2415,7 @@ void AffineForOp::getSuccessorRegions(
 
   // From the loop body, if the trip count is one, we can only branch back to
   // the parent.
-  if (index && tripCount && *tripCount == 1) {
+  if (!point.isParent() && tripCount && *tripCount == 1) {
     regions.push_back(RegionSuccessor(getResults()));
     return;
   }
@@ -2859,10 +2860,10 @@ struct AlwaysTrueOrFalseIf : public OpRewritePattern<AffineIfOp> {
 /// AffineIfOp has two regions -- `then` and `else`. The flow of data should be
 /// as follows: AffineIfOp -> `then`/`else` -> AffineIfOp
 void AffineIfOp::getSuccessorRegions(
-    std::optional<unsigned> index, SmallVectorImpl<RegionSuccessor> &regions) {
+    RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
   // If the predecessor is an AffineIfOp, then branching into both `then` and
   // `else` region is valid.
-  if (!index.has_value()) {
+  if (point.isParent()) {
     regions.reserve(2);
     regions.push_back(
         RegionSuccessor(&getThenRegion(), getThenRegion().getArguments()));

diff  --git a/mlir/lib/Dialect/Async/IR/Async.cpp b/mlir/lib/Dialect/Async/IR/Async.cpp
index 9b4fb81990c169..a05e02faf6d2f0 100644
--- a/mlir/lib/Dialect/Async/IR/Async.cpp
+++ b/mlir/lib/Dialect/Async/IR/Async.cpp
@@ -38,9 +38,8 @@ void AsyncDialect::initialize() {
 
 constexpr char kOperandSegmentSizesAttr[] = "operandSegmentSizes";
 
-OperandRange
-ExecuteOp::getEntrySuccessorOperands(std::optional<unsigned> index) {
-  assert(index && *index == 0 && "invalid region index");
+OperandRange ExecuteOp::getEntrySuccessorOperands(RegionBranchPoint point) {
+  assert(point == getBodyRegion() && "invalid region index");
   return getBodyOperands();
 }
 
@@ -53,11 +52,10 @@ bool ExecuteOp::areTypesCompatible(Type lhs, Type rhs) {
   return getValueOrTokenType(lhs) == getValueOrTokenType(rhs);
 }
 
-void ExecuteOp::getSuccessorRegions(std::optional<unsigned> index,
+void ExecuteOp::getSuccessorRegions(RegionBranchPoint point,
                                     SmallVectorImpl<RegionSuccessor> &regions) {
   // The `body` region branch back to the parent operation.
-  if (index) {
-    assert(*index == 0 && "invalid region index");
+  if (point == getBodyRegion()) {
     regions.push_back(RegionSuccessor(getBodyResults()));
     return;
   }

diff  --git a/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocation.cpp b/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocation.cpp
index 582974873263d2..9a831e4c322ea0 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocation.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocation.cpp
@@ -372,7 +372,7 @@ class BufferDeallocation : public BufferPlacementTransformationBase {
     // parent operation. In this case, we have to introduce an additional clone
     // for buffer that is passed to the argument.
     SmallVector<RegionSuccessor, 2> successorRegions;
-    regionInterface.getSuccessorRegions(/*index=*/std::nullopt,
+    regionInterface.getSuccessorRegions(/*point=*/RegionBranchPoint::parent(),
                                         successorRegions);
     auto *it =
         llvm::find_if(successorRegions, [&](RegionSuccessor &successorRegion) {
@@ -383,8 +383,7 @@ class BufferDeallocation : public BufferPlacementTransformationBase {
 
     // Determine the actual operand to introduce a clone for and rewire the
     // operand to point to the clone instead.
-    auto operands =
-        regionInterface.getEntrySuccessorOperands(argRegion->getRegionNumber());
+    auto operands = regionInterface.getEntrySuccessorOperands(argRegion);
     size_t operandIndex =
         llvm::find(it->getSuccessorInputs(), blockArg).getIndex() +
         operands.getBeginOperandIndex();
@@ -432,8 +431,7 @@ class BufferDeallocation : public BufferPlacementTransformationBase {
       // Query the regionInterface to get all successor regions of the current
       // one.
       SmallVector<RegionSuccessor, 2> successorRegions;
-      regionInterface.getSuccessorRegions(region.getRegionNumber(),
-                                          successorRegions);
+      regionInterface.getSuccessorRegions(region, successorRegions);
       // Try to find a matching region successor.
       RegionSuccessor *regionSuccessor =
           llvm::find_if(successorRegions, regionPredicate);
@@ -445,10 +443,6 @@ class BufferDeallocation : public BufferPlacementTransformationBase {
           llvm::find(regionSuccessor->getSuccessorInputs(), argValue)
               .getIndex();
 
-      std::optional<unsigned> successorRegionNumber;
-      if (Region *successorRegion = regionSuccessor->getSuccessor())
-        successorRegionNumber = successorRegion->getRegionNumber();
-
       // Iterate over all immediate terminator operations to introduce
       // new buffer allocations. Thereby, the appropriate terminator operand
       // will be adjusted to point to the newly allocated buffer instead.
@@ -456,8 +450,7 @@ class BufferDeallocation : public BufferPlacementTransformationBase {
               &region, [&](RegionBranchTerminatorOpInterface terminator) {
                 // Get the actual mutable operands for this terminator op.
                 auto terminatorOperands =
-                    terminator.getMutableSuccessorOperands(
-                        successorRegionNumber);
+                    terminator.getMutableSuccessorOperands(*regionSuccessor);
                 // Extract the source value from the current terminator.
                 // This conversion needs to exist on a separate line due to a
                 // bug in GCC conversion analysis.

diff  --git a/mlir/lib/Dialect/Bufferization/Transforms/BufferUtils.cpp b/mlir/lib/Dialect/Bufferization/Transforms/BufferUtils.cpp
index f8231cac778af6..119801f9cc92f3 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/BufferUtils.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/BufferUtils.cpp
@@ -123,7 +123,7 @@ bool BufferPlacementTransformationBase::isLoop(Operation *op) {
       return true;
     // Recurses into all region successors.
     SmallVector<RegionSuccessor, 2> successors;
-    regionInterface.getSuccessorRegions(current->getRegionNumber(), successors);
+    regionInterface.getSuccessorRegions(current, successors);
     for (RegionSuccessor &regionEntry : successors)
       if (recurse(regionEntry.getSuccessor()))
         return true;
@@ -132,7 +132,8 @@ bool BufferPlacementTransformationBase::isLoop(Operation *op) {
 
   // Start with all entry regions and test whether they induce a loop.
   SmallVector<RegionSuccessor, 2> successorRegions;
-  regionInterface.getSuccessorRegions(/*index=*/std::nullopt, successorRegions);
+  regionInterface.getSuccessorRegions(/*point=*/RegionBranchPoint::parent(),
+                                      successorRegions);
   for (RegionSuccessor &regionEntry : successorRegions) {
     if (recurse(regionEntry.getSuccessor()))
       return true;

diff  --git a/mlir/lib/Dialect/Bufferization/Transforms/BufferViewFlowAnalysis.cpp b/mlir/lib/Dialect/Bufferization/Transforms/BufferViewFlowAnalysis.cpp
index d201e024380661..98a60a48763ab1 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/BufferViewFlowAnalysis.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/BufferViewFlowAnalysis.cpp
@@ -100,16 +100,13 @@ void BufferViewFlowAnalysis::build(Operation *op) {
       // Query the RegionBranchOpInterface to find potential successor regions.
       // Extract all entry regions and wire all initial entry successor inputs.
       SmallVector<RegionSuccessor, 2> entrySuccessors;
-      regionInterface.getSuccessorRegions(/*index=*/std::nullopt,
+      regionInterface.getSuccessorRegions(/*point=*/RegionBranchPoint::parent(),
                                           entrySuccessors);
       for (RegionSuccessor &entrySuccessor : entrySuccessors) {
         // Wire the entry region's successor arguments with the initial
         // successor inputs.
         registerDependencies(
-            regionInterface.getEntrySuccessorOperands(
-                entrySuccessor.isParent()
-                    ? std::optional<unsigned>()
-                    : entrySuccessor.getSuccessor()->getRegionNumber()),
+            regionInterface.getEntrySuccessorOperands(entrySuccessor),
             entrySuccessor.getSuccessorInputs());
       }
 
@@ -118,21 +115,16 @@ void BufferViewFlowAnalysis::build(Operation *op) {
         // Iterate over all successor region entries that are reachable from the
         // current region.
         SmallVector<RegionSuccessor, 2> successorRegions;
-        regionInterface.getSuccessorRegions(region.getRegionNumber(),
-                                            successorRegions);
+        regionInterface.getSuccessorRegions(region, successorRegions);
         for (RegionSuccessor &successorRegion : successorRegions) {
-          // Determine the current region index (if any).
-          std::optional<unsigned> regionIndex;
-          Region *regionSuccessor = successorRegion.getSuccessor();
-          if (regionSuccessor)
-            regionIndex = regionSuccessor->getRegionNumber();
           // Iterate over all immediate terminator operations and wire the
           // successor inputs with the successor operands of each terminator.
           for (Block &block : region)
             if (auto terminator = dyn_cast<RegionBranchTerminatorOpInterface>(
                     block.getTerminator()))
-              registerDependencies(terminator.getSuccessorOperands(regionIndex),
-                                   successorRegion.getSuccessorInputs());
+              registerDependencies(
+                  terminator.getSuccessorOperands(successorRegion),
+                  successorRegion.getSuccessorInputs());
         }
       }
 

diff  --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index e1b8dd62450a77..9c5c322e23692b 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -455,8 +455,8 @@ ParseResult AllocaScopeOp::parse(OpAsmParser &parser, OperationState &result) {
 }
 
 void AllocaScopeOp::getSuccessorRegions(
-    std::optional<unsigned> index, SmallVectorImpl<RegionSuccessor> &regions) {
-  if (index) {
+    RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
+  if (!point.isParent()) {
     regions.push_back(RegionSuccessor(getResults()));
     return;
   }

diff  --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
index 63ce3b2a469627..b573291f0460e6 100644
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -266,9 +266,9 @@ void ExecuteRegionOp::getCanonicalizationPatterns(RewritePatternSet &results,
 /// correspond to a constant value for each operand, or null if that operand is
 /// not a constant.
 void ExecuteRegionOp::getSuccessorRegions(
-    std::optional<unsigned> index, SmallVectorImpl<RegionSuccessor> &regions) {
+    RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
   // If the predecessor is the ExecuteRegionOp, branch into the body.
-  if (!index) {
+  if (point.isParent()) {
     regions.push_back(RegionSuccessor(&getRegion()));
     return;
   }
@@ -282,8 +282,8 @@ void ExecuteRegionOp::getSuccessorRegions(
 //===----------------------------------------------------------------------===//
 
 MutableOperandRange
-ConditionOp::getMutableSuccessorOperands(std::optional<unsigned> index) {
-  assert((!index || index == getParentOp().getAfter().getRegionNumber()) &&
+ConditionOp::getMutableSuccessorOperands(RegionBranchPoint point) {
+  assert((point.isParent() || point == getParentOp().getAfter()) &&
          "condition op can only exit the loop or branch to the after"
          "region");
   // Pass all operands except the condition to the successor region.
@@ -553,7 +553,7 @@ ForOp mlir::scf::getForInductionVarOwner(Value val) {
 /// Return operands used when entering the region at 'index'. These operands
 /// correspond to the loop iterator operands, i.e., those excluding the
 /// induction variable.
-OperandRange ForOp::getEntrySuccessorOperands(std::optional<unsigned> index) {
+OperandRange ForOp::getEntrySuccessorOperands(RegionBranchPoint point) {
   return getInitArgs();
 }
 
@@ -562,7 +562,7 @@ OperandRange ForOp::getEntrySuccessorOperands(std::optional<unsigned> index) {
 /// during the flow of control. `operands` is a set of optional attributes that
 /// correspond to a constant value for each operand, or null if that operand is
 /// not a constant.
-void ForOp::getSuccessorRegions(std::optional<unsigned> index,
+void ForOp::getSuccessorRegions(RegionBranchPoint point,
                                 SmallVectorImpl<RegionSuccessor> &regions) {
   // Both the operation itself and the region may be branching into the body or
   // back into the operation itself. It is possible for loop not to enter the
@@ -1731,7 +1731,7 @@ void ForallOp::getCanonicalizationPatterns(RewritePatternSet &results,
 /// during the flow of control. `operands` is a set of optional attributes that
 /// correspond to a constant value for each operand, or null if that operand is
 /// not a constant.
-void ForallOp::getSuccessorRegions(std::optional<unsigned> index,
+void ForallOp::getSuccessorRegions(RegionBranchPoint point,
                                    SmallVectorImpl<RegionSuccessor> &regions) {
   // Both the operation itself and the region may be branching into the body or
   // back into the operation itself. It is possible for loop not to enter the
@@ -2011,10 +2011,10 @@ void IfOp::print(OpAsmPrinter &p) {
 /// during the flow of control. `operands` is a set of optional attributes that
 /// correspond to a constant value for each operand, or null if that operand is
 /// not a constant.
-void IfOp::getSuccessorRegions(std::optional<unsigned> index,
+void IfOp::getSuccessorRegions(RegionBranchPoint point,
                                SmallVectorImpl<RegionSuccessor> &regions) {
   // The `then` and the `else` region branch back to the parent operation.
-  if (index) {
+  if (!point.isParent()) {
     regions.push_back(RegionSuccessor(getResults()));
     return;
   }
@@ -3042,7 +3042,7 @@ void ParallelOp::getCanonicalizationPatterns(RewritePatternSet &results,
 /// correspond to a constant value for each operand, or null if that operand is
 /// not a constant.
 void ParallelOp::getSuccessorRegions(
-    std::optional<unsigned> index, SmallVectorImpl<RegionSuccessor> &regions) {
+    RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
   // Both the operation itself and the region may be branching into the body or
   // back into the operation itself. It is possible for loop not to enter the
   // body.
@@ -3169,8 +3169,8 @@ void WhileOp::build(::mlir::OpBuilder &odsBuilder,
     afterBuilder(odsBuilder, odsState.location, afterBlock->getArguments());
 }
 
-OperandRange WhileOp::getEntrySuccessorOperands(std::optional<unsigned> index) {
-  assert(index && *index == 0 &&
+OperandRange WhileOp::getEntrySuccessorOperands(RegionBranchPoint point) {
+  assert(point == getBefore() &&
          "WhileOp is expected to branch only to the first region");
 
   return getInits();
@@ -3192,17 +3192,18 @@ Block::BlockArgListType WhileOp::getAfterArguments() {
   return getAfterBody()->getArguments();
 }
 
-void WhileOp::getSuccessorRegions(std::optional<unsigned> index,
+void WhileOp::getSuccessorRegions(RegionBranchPoint point,
                                   SmallVectorImpl<RegionSuccessor> &regions) {
   // The parent op always branches to the condition region.
-  if (!index) {
+  if (point.isParent()) {
     regions.emplace_back(&getBefore(), getBefore().getArguments());
     return;
   }
 
-  assert(*index < 2 && "there are only two regions in a WhileOp");
+  assert(llvm::is_contained({&getAfter(), &getBefore()}, point) &&
+         "there are only two regions in a WhileOp");
   // The body region always branches back to the condition region.
-  if (*index == 1) {
+  if (point == getAfter()) {
     regions.emplace_back(&getBefore(), getBefore().getArguments());
     return;
   }
@@ -4023,10 +4024,9 @@ Block &scf::IndexSwitchOp::getCaseBlock(unsigned idx) {
 }
 
 void IndexSwitchOp::getSuccessorRegions(
-    std::optional<unsigned> index,
-    SmallVectorImpl<RegionSuccessor> &successors) {
+    RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &successors) {
   // All regions branch back to the parent op.
-  if (index) {
+  if (!point.isParent()) {
     successors.emplace_back(getResults());
     return;
   }

diff  --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp
index c52a9c5004aaaf..78b06d9ce033f8 100644
--- a/mlir/lib/Dialect/Shape/IR/Shape.cpp
+++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp
@@ -335,11 +335,11 @@ void AssumingOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
 
 // See RegionBranchOpInterface in Interfaces/ControlFlowInterfaces.td
 void AssumingOp::getSuccessorRegions(
-    std::optional<unsigned> index, SmallVectorImpl<RegionSuccessor> &regions) {
+    RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
   // AssumingOp has unconditional control flow into the region and back to the
   // parent, so return the correct RegionSuccessor purely based on the index
   // being None or 0.
-  if (index) {
+  if (!point.isParent()) {
     regions.push_back(RegionSuccessor(getResults()));
     return;
   }

diff  --git a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
index 7bc7272b054129..518bfc3931e8d3 100644
--- a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
@@ -86,23 +86,25 @@ ensurePayloadIsSeparateFromTransform(transform::TransformOpInterface transform,
 // AlternativesOp
 //===----------------------------------------------------------------------===//
 
-OperandRange transform::AlternativesOp::getEntrySuccessorOperands(
-    std::optional<unsigned> index) {
-  if (index && getOperation()->getNumOperands() == 1)
+OperandRange
+transform::AlternativesOp::getEntrySuccessorOperands(RegionBranchPoint point) {
+  if (!point.isParent() && getOperation()->getNumOperands() == 1)
     return getOperation()->getOperands();
   return OperandRange(getOperation()->operand_end(),
                       getOperation()->operand_end());
 }
 
 void transform::AlternativesOp::getSuccessorRegions(
-    std::optional<unsigned> index, SmallVectorImpl<RegionSuccessor> &regions) {
+    RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
   for (Region &alternative : llvm::drop_begin(
-           getAlternatives(), index.has_value() ? *index + 1 : 0)) {
+           getAlternatives(),
+           point.isParent() ? 0
+                            : point.getRegionOrNull()->getRegionNumber() + 1)) {
     regions.emplace_back(&alternative, !getOperands().empty()
                                            ? alternative.getArguments()
                                            : Block::BlockArgListType());
   }
-  if (index.has_value())
+  if (!point.isParent())
     regions.emplace_back(getOperation()->getResults());
 }
 
@@ -1159,24 +1161,24 @@ void transform::ForeachOp::getEffects(
 }
 
 void transform::ForeachOp::getSuccessorRegions(
-    std::optional<unsigned> index, SmallVectorImpl<RegionSuccessor> &regions) {
+    RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
   Region *bodyRegion = &getBody();
-  if (!index) {
+  if (point.isParent()) {
     regions.emplace_back(bodyRegion, bodyRegion->getArguments());
     return;
   }
 
   // Branch back to the region or the parent.
-  assert(*index == 0 && "unexpected region index");
+  assert(point == getBody() && "unexpected region index");
   regions.emplace_back(bodyRegion, bodyRegion->getArguments());
   regions.emplace_back();
 }
 
 OperandRange
-transform::ForeachOp::getEntrySuccessorOperands(std::optional<unsigned> index) {
+transform::ForeachOp::getEntrySuccessorOperands(RegionBranchPoint point) {
   // The iteration variable op handle is mapped to a subset (one op to be
   // precise) of the payload ops of the ForeachOp operand.
-  assert(index && *index == 0 && "unexpected region index");
+  assert(point == getBody() && "unexpected region index");
   return getOperation()->getOperands();
 }
 
@@ -2178,9 +2180,9 @@ void transform::SequenceOp::getEffects(
   getPotentialTopLevelEffects(effects);
 }
 
-OperandRange transform::SequenceOp::getEntrySuccessorOperands(
-    std::optional<unsigned> index) {
-  assert(index && *index == 0 && "unexpected region index");
+OperandRange
+transform::SequenceOp::getEntrySuccessorOperands(RegionBranchPoint point) {
+  assert(point == getBody() && "unexpected region index");
   if (getOperation()->getNumOperands() > 0)
     return getOperation()->getOperands();
   return OperandRange(getOperation()->operand_end(),
@@ -2188,8 +2190,8 @@ OperandRange transform::SequenceOp::getEntrySuccessorOperands(
 }
 
 void transform::SequenceOp::getSuccessorRegions(
-    std::optional<unsigned> index, SmallVectorImpl<RegionSuccessor> &regions) {
-  if (!index) {
+    RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
+  if (point.isParent()) {
     Region *bodyRegion = &getBody();
     regions.emplace_back(bodyRegion, getNumOperands() != 0
                                          ? bodyRegion->getArguments()
@@ -2197,7 +2199,7 @@ void transform::SequenceOp::getSuccessorRegions(
     return;
   }
 
-  assert(*index == 0 && "unexpected region index");
+  assert(point == getBody() && "unexpected region index");
   regions.emplace_back(getOperation()->getResults());
 }
 

diff  --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 4e9364611b257d..88bda3931a5a11 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -5821,8 +5821,8 @@ ParseResult WarpExecuteOnLane0Op::parse(OpAsmParser &parser,
 }
 
 void WarpExecuteOnLane0Op::getSuccessorRegions(
-    std::optional<unsigned> index, SmallVectorImpl<RegionSuccessor> &regions) {
-  if (index) {
+    RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
+  if (!point.isParent()) {
     regions.push_back(RegionSuccessor(getResults()));
     return;
   }

diff  --git a/mlir/lib/Interfaces/ControlFlowInterfaces.cpp b/mlir/lib/Interfaces/ControlFlowInterfaces.cpp
index cc90da370de693..b3166155e84f93 100644
--- a/mlir/lib/Interfaces/ControlFlowInterfaces.cpp
+++ b/mlir/lib/Interfaces/ControlFlowInterfaces.cpp
@@ -84,18 +84,18 @@ detail::verifyBranchSuccessorOperands(Operation *op, unsigned succNo,
 // RegionBranchOpInterface
 //===----------------------------------------------------------------------===//
 
-static InFlightDiagnostic &
-printRegionEdgeName(InFlightDiagnostic &diag, std::optional<unsigned> sourceNo,
-                    std::optional<unsigned> succRegionNo) {
+static InFlightDiagnostic &printRegionEdgeName(InFlightDiagnostic &diag,
+                                               RegionBranchPoint sourceNo,
+                                               RegionBranchPoint succRegionNo) {
   diag << "from ";
-  if (sourceNo)
-    diag << "Region #" << sourceNo.value();
+  if (Region *region = sourceNo.getRegionOrNull())
+    diag << "Region #" << region->getRegionNumber();
   else
     diag << "parent operands";
 
   diag << " to ";
-  if (succRegionNo)
-    diag << "Region #" << succRegionNo.value();
+  if (Region *region = succRegionNo.getRegionOrNull())
+    diag << "Region #" << region->getRegionNumber();
   else
     diag << "parent results";
   return diag;
@@ -107,28 +107,24 @@ printRegionEdgeName(InFlightDiagnostic &diag, std::optional<unsigned> sourceNo,
 /// inputs that flow from `sourceIndex' to the given region, or std::nullopt if
 /// the exact type match verification is not necessary (e.g., if the Op verifies
 /// the match itself).
-static LogicalResult verifyTypesAlongAllEdges(
-    Operation *op, std::optional<unsigned> sourceNo,
-    function_ref<FailureOr<TypeRange>(std::optional<unsigned>)>
-        getInputsTypesForRegion) {
+static LogicalResult
+verifyTypesAlongAllEdges(Operation *op, RegionBranchPoint sourcePoint,
+                         function_ref<FailureOr<TypeRange>(RegionBranchPoint)>
+                             getInputsTypesForRegion) {
   auto regionInterface = cast<RegionBranchOpInterface>(op);
 
   SmallVector<RegionSuccessor, 2> successors;
-  regionInterface.getSuccessorRegions(sourceNo, successors);
+  regionInterface.getSuccessorRegions(sourcePoint, successors);
 
   for (RegionSuccessor &succ : successors) {
-    std::optional<unsigned> succRegionNo;
-    if (!succ.isParent())
-      succRegionNo = succ.getSuccessor()->getRegionNumber();
-
-    FailureOr<TypeRange> sourceTypes = getInputsTypesForRegion(succRegionNo);
+    FailureOr<TypeRange> sourceTypes = getInputsTypesForRegion(succ);
     if (failed(sourceTypes))
       return failure();
 
     TypeRange succInputsTypes = succ.getSuccessorInputs().getTypes();
     if (sourceTypes->size() != succInputsTypes.size()) {
       InFlightDiagnostic diag = op->emitOpError(" region control flow edge ");
-      return printRegionEdgeName(diag, sourceNo, succRegionNo)
+      return printRegionEdgeName(diag, sourcePoint, succ)
              << ": source has " << sourceTypes->size()
              << " operands, but target successor needs "
              << succInputsTypes.size();
@@ -140,7 +136,7 @@ static LogicalResult verifyTypesAlongAllEdges(
       Type inputType = std::get<1>(typesIdx.value());
       if (!regionInterface.areTypesCompatible(sourceType, inputType)) {
         InFlightDiagnostic diag = op->emitOpError(" along control flow edge ");
-        return printRegionEdgeName(diag, sourceNo, succRegionNo)
+        return printRegionEdgeName(diag, sourcePoint, succ)
                << ": source type #" << typesIdx.index() << " " << sourceType
                << " should match input type #" << typesIdx.index() << " "
                << inputType;
@@ -154,13 +150,13 @@ static LogicalResult verifyTypesAlongAllEdges(
 LogicalResult detail::verifyTypesAlongControlFlowEdges(Operation *op) {
   auto regionInterface = cast<RegionBranchOpInterface>(op);
 
-  auto inputTypesFromParent =
-      [&](std::optional<unsigned> regionNo) -> TypeRange {
+  auto inputTypesFromParent = [&](RegionBranchPoint regionNo) -> TypeRange {
     return regionInterface.getEntrySuccessorOperands(regionNo).getTypes();
   };
 
   // Verify types along control flow edges originating from the parent.
-  if (failed(verifyTypesAlongAllEdges(op, std::nullopt, inputTypesFromParent)))
+  if (failed(verifyTypesAlongAllEdges(op, RegionBranchPoint::parent(),
+                                      inputTypesFromParent)))
     return failure();
 
   auto areTypesCompatible = [&](TypeRange lhs, TypeRange rhs) {
@@ -176,8 +172,7 @@ LogicalResult detail::verifyTypesAlongControlFlowEdges(Operation *op) {
   };
 
   // Verify types along control flow edges originating from each region.
-  for (unsigned regionNo : llvm::seq(0U, op->getNumRegions())) {
-    Region &region = op->getRegion(regionNo);
+  for (Region &region : op->getRegions()) {
 
     // Since there can be multiple terminators implementing the
     // `RegionBranchTerminatorOpInterface`, all should have the same operand
@@ -195,7 +190,7 @@ LogicalResult detail::verifyTypesAlongControlFlowEdges(Operation *op) {
       continue;
 
     auto inputTypesForRegion =
-        [&](std::optional<unsigned> succRegionNo) -> FailureOr<TypeRange> {
+        [&](RegionBranchPoint succRegionNo) -> FailureOr<TypeRange> {
       std::optional<OperandRange> regionReturnOperands;
       for (RegionBranchTerminatorOpInterface regionReturnOp : regionReturnOps) {
         auto terminatorOperands =
@@ -211,7 +206,7 @@ LogicalResult detail::verifyTypesAlongControlFlowEdges(Operation *op) {
         if (!areTypesCompatible(regionReturnOperands->getTypes(),
                                 terminatorOperands.getTypes())) {
           InFlightDiagnostic diag = op->emitOpError(" along control flow edge");
-          return printRegionEdgeName(diag, regionNo, succRegionNo)
+          return printRegionEdgeName(diag, region, succRegionNo)
                  << " operands mismatch between return-like terminators";
         }
       }
@@ -220,7 +215,7 @@ LogicalResult detail::verifyTypesAlongControlFlowEdges(Operation *op) {
       return TypeRange(regionReturnOperands->getTypes());
     };
 
-    if (failed(verifyTypesAlongAllEdges(op, regionNo, inputTypesForRegion)))
+    if (failed(verifyTypesAlongAllEdges(op, region, inputTypesForRegion)))
       return failure();
   }
 
@@ -237,24 +232,24 @@ static bool isRegionReachable(Region *begin, Region *r) {
   visited[begin->getRegionNumber()] = true;
 
   // Retrieve all successors of the region and enqueue them in the worklist.
-  SmallVector<unsigned> worklist;
-  auto enqueueAllSuccessors = [&](unsigned index) {
+  SmallVector<Region *> worklist;
+  auto enqueueAllSuccessors = [&](Region *region) {
     SmallVector<RegionSuccessor> successors;
-    op.getSuccessorRegions(index, successors);
+    op.getSuccessorRegions(region, successors);
     for (RegionSuccessor successor : successors)
       if (!successor.isParent())
-        worklist.push_back(successor.getSuccessor()->getRegionNumber());
+        worklist.push_back(successor.getSuccessor());
   };
-  enqueueAllSuccessors(begin->getRegionNumber());
+  enqueueAllSuccessors(begin);
 
   // Process all regions in the worklist via DFS.
   while (!worklist.empty()) {
-    unsigned nextRegion = worklist.pop_back_val();
-    if (nextRegion == r->getRegionNumber())
+    Region *nextRegion = worklist.pop_back_val();
+    if (nextRegion == r)
       return true;
-    if (visited[nextRegion])
+    if (visited[nextRegion->getRegionNumber()])
       continue;
-    visited[nextRegion] = true;
+    visited[nextRegion->getRegionNumber()] = true;
     enqueueAllSuccessors(nextRegion);
   }
 

diff  --git a/mlir/lib/Transforms/RemoveDeadValues.cpp b/mlir/lib/Transforms/RemoveDeadValues.cpp
index ce19dc667f009d..19a84db34dce02 100644
--- a/mlir/lib/Transforms/RemoveDeadValues.cpp
+++ b/mlir/lib/Transforms/RemoveDeadValues.cpp
@@ -316,15 +316,11 @@ static void cleanRegionBranchOp(RegionBranchOpInterface regionBranchOp,
   // Return the successors of `region` if the latter is not null. Else return
   // the successors of `regionBranchOp`.
   auto getSuccessors = [&](Region *region = nullptr) {
-    std::optional<unsigned> index =
-        region ? std::optional(region->getRegionNumber()) : std::nullopt;
+    auto point = region ? region : RegionBranchPoint::parent();
     SmallVector<Attribute> operandAttributes(regionBranchOp->getNumOperands(),
                                              nullptr);
     SmallVector<RegionSuccessor> successors;
-    if (!index)
-      regionBranchOp.getEntrySuccessorRegions(operandAttributes, successors);
-    else
-      regionBranchOp.getSuccessorRegions(index, successors);
+    regionBranchOp.getSuccessorRegions(point, successors);
     return successors;
   };
 
@@ -333,14 +329,10 @@ static void cleanRegionBranchOp(RegionBranchOpInterface regionBranchOp,
   // forwarded to `successor`.
   auto getForwardedOpOperands = [&](const RegionSuccessor &successor,
                                     Operation *terminator = nullptr) {
-    Region *successorRegion = successor.getSuccessor();
-    std::optional<unsigned> index =
-        successorRegion ? std::optional(successorRegion->getRegionNumber())
-                        : std::nullopt;
     OperandRange operands =
         terminator ? cast<RegionBranchTerminatorOpInterface>(terminator)
-                         .getSuccessorOperands(index)
-                   : regionBranchOp.getEntrySuccessorOperands(index);
+                         .getSuccessorOperands(successor)
+                   : regionBranchOp.getEntrySuccessorOperands(successor);
     SmallVector<OpOperand *> opOperands = operandsToOpOperands(operands);
     return opOperands;
   };

diff  --git a/mlir/test/lib/Analysis/DataFlow/TestDenseBackwardDataFlowAnalysis.cpp b/mlir/test/lib/Analysis/DataFlow/TestDenseBackwardDataFlowAnalysis.cpp
index a33b523d5d192f..8bfd01d828060a 100644
--- a/mlir/test/lib/Analysis/DataFlow/TestDenseBackwardDataFlowAnalysis.cpp
+++ b/mlir/test/lib/Analysis/DataFlow/TestDenseBackwardDataFlowAnalysis.cpp
@@ -60,8 +60,8 @@ class NextAccessAnalysis : public DenseBackwardDataFlowAnalysis<NextAccess> {
                                     NextAccess *before) override;
 
   void visitRegionBranchControlFlowTransfer(RegionBranchOpInterface branch,
-                                            std::optional<unsigned> regionFrom,
-                                            std::optional<unsigned> regionTo,
+                                            RegionBranchPoint regionFrom,
+                                            RegionBranchPoint regionTo,
                                             const NextAccess &after,
                                             NextAccess *before) override;
 
@@ -124,15 +124,15 @@ void NextAccessAnalysis::visitCallControlFlowTransfer(
 }
 
 void NextAccessAnalysis::visitRegionBranchControlFlowTransfer(
-    RegionBranchOpInterface branch, std::optional<unsigned> regionFrom,
-    std::optional<unsigned> regionTo, const NextAccess &after,
-    NextAccess *before) {
+    RegionBranchOpInterface branch, RegionBranchPoint regionFrom,
+    RegionBranchPoint regionTo, const NextAccess &after, NextAccess *before) {
   auto testStoreWithARegion =
       dyn_cast<::test::TestStoreWithARegion>(branch.getOperation());
 
   if (testStoreWithARegion &&
-      ((!regionTo && !testStoreWithARegion.getStoreBeforeRegion()) ||
-       (!regionFrom && testStoreWithARegion.getStoreBeforeRegion()))) {
+      ((regionTo.isParent() && !testStoreWithARegion.getStoreBeforeRegion()) ||
+       (regionFrom.isParent() &&
+        testStoreWithARegion.getStoreBeforeRegion()))) {
     visitOperation(branch, static_cast<const NextAccess &>(after),
                    static_cast<NextAccess *>(before));
   } else {
@@ -219,7 +219,7 @@ struct TestNextAccessPass
 
       SmallVector<Attribute> entryPointNextAccess;
       SmallVector<RegionSuccessor> regionSuccessors;
-      iface.getSuccessorRegions(std::nullopt, regionSuccessors);
+      iface.getSuccessorRegions(RegionBranchPoint::parent(), regionSuccessors);
       for (const RegionSuccessor &successor : regionSuccessors) {
         if (!successor.getSuccessor() || successor.getSuccessor()->empty())
           continue;

diff  --git a/mlir/test/lib/Dialect/Test/TestDialect.cpp b/mlir/test/lib/Dialect/Test/TestDialect.cpp
index 57a6ab387281dc..34ed7a1a66fe33 100644
--- a/mlir/test/lib/Dialect/Test/TestDialect.cpp
+++ b/mlir/test/lib/Dialect/Test/TestDialect.cpp
@@ -931,17 +931,17 @@ ParseResult RegionIfOp::parse(OpAsmParser &parser, OperationState &result) {
                                 parser.getCurrentLocation(), result.operands);
 }
 
-OperandRange
-RegionIfOp::getEntrySuccessorOperands(std::optional<unsigned> index) {
-  assert(index && *index < 2 && "invalid region index");
+OperandRange RegionIfOp::getEntrySuccessorOperands(RegionBranchPoint point) {
+  assert(llvm::is_contained({&getThenRegion(), &getElseRegion()}, point) &&
+         "invalid region index");
   return getOperands();
 }
 
 void RegionIfOp::getSuccessorRegions(
-    std::optional<unsigned> index, SmallVectorImpl<RegionSuccessor> &regions) {
+    RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
   // We always branch to the join region.
-  if (index.has_value()) {
-    if (index.value() < 2)
+  if (!point.isParent()) {
+    if (point != getJoinRegion())
       regions.push_back(RegionSuccessor(&getJoinRegion(), getJoinArgs()));
     else
       regions.push_back(RegionSuccessor(getResults()));
@@ -964,11 +964,11 @@ void RegionIfOp::getRegionInvocationBounds(
 // AnyCondOp
 //===----------------------------------------------------------------------===//
 
-void AnyCondOp::getSuccessorRegions(std::optional<unsigned> index,
+void AnyCondOp::getSuccessorRegions(RegionBranchPoint point,
                                     SmallVectorImpl<RegionSuccessor> &regions) {
   // The parent op branches into the only region, and the region branches back
   // to the parent op.
-  if (!index)
+  if (point.isParent())
     regions.emplace_back(&getRegion());
   else
     regions.emplace_back(getResults());
@@ -985,17 +985,16 @@ void AnyCondOp::getRegionInvocationBounds(
 //===----------------------------------------------------------------------===//
 
 void LoopBlockOp::getSuccessorRegions(
-    std::optional<unsigned> index, SmallVectorImpl<RegionSuccessor> &regions) {
+    RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
   regions.emplace_back(&getBody(), getBody().getArguments());
-  if (!index)
+  if (point.isParent())
     return;
 
   regions.emplace_back((*this)->getResults());
 }
 
-OperandRange
-LoopBlockOp::getEntrySuccessorOperands(std::optional<unsigned> index) {
-  assert(index == 0);
+OperandRange LoopBlockOp::getEntrySuccessorOperands(RegionBranchPoint point) {
+  assert(point == getBody());
   return getInitMutable();
 }
 
@@ -1003,10 +1002,9 @@ LoopBlockOp::getEntrySuccessorOperands(std::optional<unsigned> index) {
 // LoopBlockTerminatorOp
 //===----------------------------------------------------------------------===//
 
-MutableOperandRange LoopBlockTerminatorOp::getMutableSuccessorOperands(
-    std::optional<unsigned> index) {
-  assert(!index || index == 0);
-  if (!index)
+MutableOperandRange
+LoopBlockTerminatorOp::getMutableSuccessorOperands(RegionBranchPoint point) {
+  if (point.isParent())
     return getExitArgMutable();
   return getNextIterArgMutable();
 }
@@ -1313,12 +1311,11 @@ MutableOperandRange TestCallOnDeviceOp::getArgOperandsMutable() {
 }
 
 void TestStoreWithARegion::getSuccessorRegions(
-    std::optional<unsigned> index, SmallVectorImpl<RegionSuccessor> &regions) {
-  if (!index) {
+    RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
+  if (point.isParent())
     regions.emplace_back(&getBody(), getBody().front().getArguments());
-  } else {
+  else
     regions.emplace_back();
-  }
 }
 
 LogicalResult

diff  --git a/mlir/unittests/Interfaces/ControlFlowInterfacesTest.cpp b/mlir/unittests/Interfaces/ControlFlowInterfacesTest.cpp
index a507baa6445d97..f1aae15393fd3f 100644
--- a/mlir/unittests/Interfaces/ControlFlowInterfacesTest.cpp
+++ b/mlir/unittests/Interfaces/ControlFlowInterfacesTest.cpp
@@ -37,7 +37,7 @@ struct MutuallyExclusiveRegionsOp
   }
 
   // Regions have no successors.
-  void getSuccessorRegions(std::optional<unsigned> index,
+  void getSuccessorRegions(RegionBranchPoint point,
                            SmallVectorImpl<RegionSuccessor> &regions) {}
 };
 
@@ -51,14 +51,13 @@ struct LoopRegionsOp
 
   static StringRef getOperationName() { return "cftest.loop_regions_op"; }
 
-  void getSuccessorRegions(std::optional<unsigned> index,
+  void getSuccessorRegions(RegionBranchPoint point,
                            SmallVectorImpl<RegionSuccessor> &regions) {
-    if (index) {
-      if (*index == 1)
+    if (Region *region = point.getRegionOrNull()) {
+      if (point == (*this)->getRegion(1))
         // This region also branches back to the parent.
         regions.push_back(RegionSuccessor());
-      regions.push_back(
-          RegionSuccessor(&getOperation()->getRegion(*index % kNumRegions)));
+      regions.push_back(RegionSuccessor(region));
     }
   }
 };
@@ -74,11 +73,11 @@ struct DoubleLoopRegionsOp
     return "cftest.double_loop_regions_op";
   }
 
-  void getSuccessorRegions(std::optional<unsigned> index,
+  void getSuccessorRegions(RegionBranchPoint point,
                            SmallVectorImpl<RegionSuccessor> &regions) {
-    if (index.has_value()) {
+    if (Region *region = point.getRegionOrNull()) {
       regions.push_back(RegionSuccessor());
-      regions.push_back(RegionSuccessor(&getOperation()->getRegion(*index)));
+      regions.push_back(RegionSuccessor(region));
     }
   }
 };
@@ -92,9 +91,9 @@ struct SequentialRegionsOp
   static StringRef getOperationName() { return "cftest.sequential_regions_op"; }
 
   // Region 0 has Region 1 as a successor.
-  void getSuccessorRegions(std::optional<unsigned> index,
+  void getSuccessorRegions(RegionBranchPoint point,
                            SmallVectorImpl<RegionSuccessor> &regions) {
-    if (index == 0u) {
+    if (point == (*this)->getRegion(0)) {
       Operation *thisOp = this->getOperation();
       regions.push_back(RegionSuccessor(&thisOp->getRegion(1)));
     }


        


More information about the Mlir-commits mailing list