[Mlir-commits] [mlir] 10ae8ae - [mlir][NFC] Make `ReturnLike` trait imply `RegionBranchTerminatorOpInterface`

Markus Böck llvmlistbot at llvm.org
Tue Aug 8 13:12:11 PDT 2023


Author: Markus Böck
Date: 2023-08-08T22:11:39+02:00
New Revision: 10ae8ae8375d6b69064204338a33500917749da9

URL: https://github.com/llvm/llvm-project/commit/10ae8ae8375d6b69064204338a33500917749da9
DIFF: https://github.com/llvm/llvm-project/commit/10ae8ae8375d6b69064204338a33500917749da9.diff

LOG: [mlir][NFC] Make `ReturnLike` trait imply `RegionBranchTerminatorOpInterface`

This implication already held de facto. Plenty of users and wrapper functions were written specifically to handle the "return-like or RegionBranchTerminatorOpInterface" case. They existed only because, until recently, ODS lacked the features needed to express the implication directly.
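The following is a minimal, self-contained C++ model of the pre-commit situation that illustrates the duplication. It does not use MLIR at all. `Operation`, `TerminatorInterface` and `ConditionOp` are hypothetical stand-ins. The two helpers are only named after the removed `isRegionReturnLike` and `getRegionBranchSuccessorOperands`; they are not those functions. The model shows that an op carrying only the trait did not implement the interface, so every query had to test both.

```cpp
// Hypothetical, MLIR-free model of the situation *before* this commit.
// None of these types are MLIR API; they only mimic the shape of the problem.
#include <optional>
#include <utility>
#include <vector>

// Stand-in for mlir::Operation: operands plus a bit for the ReturnLike trait.
struct Operation {
  std::vector<int> operands;
  bool hasReturnLikeTrait;
  explicit Operation(std::vector<int> ops, bool returnLike = false)
      : operands(std::move(ops)), hasReturnLikeTrait(returnLike) {}
  virtual ~Operation() = default;
};

// Stand-in for RegionBranchTerminatorOpInterface. An op with only the trait
// does NOT implement this, which is the root of the duplication.
struct TerminatorInterface {
  virtual ~TerminatorInterface() = default;
  virtual std::vector<int>
  getSuccessorOperands(std::optional<unsigned> index) const = 0;
};

// An op implementing the interface by hand. Here the first operand is
// assumed to be a condition that is not forwarded to the successor.
struct ConditionOp : Operation, TerminatorInterface {
  explicit ConditionOp(std::vector<int> ops) : Operation(std::move(ops)) {}
  std::vector<int>
  getSuccessorOperands(std::optional<unsigned>) const override {
    if (operands.empty())
      return {};
    return std::vector<int>(operands.begin() + 1, operands.end());
  }
};

// Every consumer had to check both spellings, like `isRegionReturnLike`.
bool isRegionReturnLike(const Operation &op) {
  return dynamic_cast<const TerminatorInterface *>(&op) != nullptr ||
         op.hasReturnLikeTrait;
}

// ...and special-case the trait when fetching operands, like
// `getRegionBranchSuccessorOperands` (std::nullopt = not a region terminator).
std::optional<std::vector<int>>
getRegionBranchSuccessorOperands(const Operation &op,
                                 std::optional<unsigned> index) {
  if (const auto *iface = dynamic_cast<const TerminatorInterface *>(&op))
    return iface->getSuccessorOperands(index);
  if (op.hasReturnLikeTrait)
    return op.operands; // trait-only op: forwards all of its operands
  return std::nullopt;
}
```

A trait-only op and an interface op answer the same question through different mechanisms. As a result, each call site either went through wrappers like these or repeated the `||` check inline.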

With the new capabilities of traits, we can make `ReturnLike` imply `RegionBranchTerminatorOpInterface` and auto-generate proper definitions for its methods.
All occurrences of `isa<RegionBranchTerminatorOpInterface>() || hasTrait<ReturnLike>()`, and the wrapper functions built around that check, have been removed.
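Below is a minimal sketch of the new arrangement, again as a hypothetical MLIR-free model. The `ReturnLike` template plays the role of the trait. It derives from the interface stand-in and supplies a default that forwards every operand, mirroring `MutableOperandRange(*this)` in the ODS change. The consumer then needs a single `dynamic_cast`, the analogue of `dyn_cast<RegionBranchTerminatorOpInterface>`. CRTP and virtual dispatch are assumptions of this model only; ODS does not generate the code this way.

```cpp
// Hypothetical, MLIR-free model of the design *after* this commit: a
// ReturnLike "trait" that implies the terminator interface and supplies a
// default definition. None of these types are MLIR API.
#include <optional>
#include <utility>
#include <vector>

// Stand-in for mlir::Operation.
struct Operation {
  std::vector<int> operands;
  explicit Operation(std::vector<int> ops) : operands(std::move(ops)) {}
  virtual ~Operation() = default;
};

// Stand-in for RegionBranchTerminatorOpInterface. In this model,
// `index == std::nullopt` means "the parent op is the successor".
struct TerminatorInterface {
  virtual ~TerminatorInterface() = default;
  virtual std::vector<int>
  getSuccessorOperands(std::optional<unsigned> index) const = 0;
};

// Stand-in for the ReturnLike trait: attaching it attaches the interface too,
// with a default that forwards all operands and ignores the region index.
template <typename ConcreteOp>
struct ReturnLike : TerminatorInterface {
  std::vector<int>
  getSuccessorOperands(std::optional<unsigned>) const override {
    return static_cast<const ConcreteOp &>(*this).operands;
  }
};

// Trait-only op: gets the interface implementation "for free".
struct YieldOp : Operation, ReturnLike<YieldOp> {
  explicit YieldOp(std::vector<int> ops) : Operation(std::move(ops)) {}
};

// Op implementing the interface by hand (first operand assumed to be a
// condition that is not forwarded).
struct ConditionOp : Operation, TerminatorInterface {
  explicit ConditionOp(std::vector<int> ops) : Operation(std::move(ops)) {}
  std::vector<int>
  getSuccessorOperands(std::optional<unsigned>) const override {
    if (operands.empty())
      return {};
    return std::vector<int>(operands.begin() + 1, operands.end());
  }
};

// Non-terminator op: implements neither the trait nor the interface.
struct AddOp : Operation {
  explicit AddOp(std::vector<int> ops) : Operation(std::move(ops)) {}
};

// Consumers now need one check instead of `isa<...>() || hasTrait<...>()`.
std::optional<std::vector<int>>
forwardedOperands(const Operation &op, std::optional<unsigned> index) {
  if (const auto *term = dynamic_cast<const TerminatorInterface *>(&op))
    return term->getSuccessorOperands(index);
  return std::nullopt;
}
```

As with the generated `getMutableSuccessorOperands`, the default here returns all operands regardless of `index`. An op that forwards only some of its operands still implements the interface itself, as `ConditionOp` does in this model.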

Differential Revision: https://reviews.llvm.org/D157402

Added: 
    

Modified: 
    mlir/include/mlir/Analysis/DataFlow/SparseAnalysis.h
    mlir/include/mlir/Interfaces/ControlFlowInterfaces.h
    mlir/include/mlir/Interfaces/ControlFlowInterfaces.td
    mlir/lib/Analysis/AliasAnalysis/LocalAliasAnalysis.cpp
    mlir/lib/Analysis/DataFlow/DenseAnalysis.cpp
    mlir/lib/Analysis/DataFlow/LivenessAnalysis.cpp
    mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp
    mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
    mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocation.cpp
    mlir/lib/Dialect/Bufferization/Transforms/BufferOptimizations.cpp
    mlir/lib/Dialect/Bufferization/Transforms/BufferViewFlowAnalysis.cpp
    mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp
    mlir/lib/Interfaces/ControlFlowInterfaces.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Analysis/DataFlow/SparseAnalysis.h b/mlir/include/mlir/Analysis/DataFlow/SparseAnalysis.h
index c0cb09ddfd8c2a..b8514481a044c0 100644
--- a/mlir/include/mlir/Analysis/DataFlow/SparseAnalysis.h
+++ b/mlir/include/mlir/Analysis/DataFlow/SparseAnalysis.h
@@ -397,13 +397,13 @@ class AbstractSparseBackwardDataFlowAnalysis : public DataFlowAnalysis {
   void visitRegionSuccessors(RegionBranchOpInterface branch,
                              ArrayRef<AbstractSparseLattice *> operands);
 
-  /// Visit a terminator (an op implementing `RegionBranchTerminatorOpInterface`
-  /// or a return-like op) to compute the lattice values of its operands, given
-  /// its parent op `branch`. The lattice value of an operand is determined
-  /// based on the corresponding arguments in `terminator`'s region
-  /// successor(s).
-  void visitRegionSuccessorsFromTerminator(Operation *terminator,
-                                           RegionBranchOpInterface branch);
+  /// Visit a `RegionBranchTerminatorOpInterface` to compute the lattice values
+  /// of its operands, given its parent op `branch`. The lattice value of an
+  /// operand is determined based on the corresponding arguments in
+  /// `terminator`'s region successor(s).
+  void visitRegionSuccessorsFromTerminator(
+      RegionBranchTerminatorOpInterface terminator,
+      RegionBranchOpInterface branch);
 
   /// Get the lattice element for a value, and also set up
   /// dependencies so that the analysis on the given ProgramPoint is re-invoked

diff  --git a/mlir/include/mlir/Interfaces/ControlFlowInterfaces.h b/mlir/include/mlir/Interfaces/ControlFlowInterfaces.h
index e3c262da17039a..9dab4358f4f3a6 100644
--- a/mlir/include/mlir/Interfaces/ControlFlowInterfaces.h
+++ b/mlir/include/mlir/Interfaces/ControlFlowInterfaces.h
@@ -229,32 +229,6 @@ Region *getEnclosingRepetitiveRegion(Operation *op);
 /// exists.
 Region *getEnclosingRepetitiveRegion(Value value);
 
-//===----------------------------------------------------------------------===//
-// RegionBranchTerminatorOpInterface
-//===----------------------------------------------------------------------===//
-
-/// Returns true if the given operation is either annotated with the
-/// `ReturnLike` trait or implements the `RegionBranchTerminatorOpInterface`.
-bool isRegionReturnLike(Operation *operation);
-
-/// Returns the mutable operands that are passed to the region with the given
-/// `regionIndex`. If the operation does not implement the
-/// `RegionBranchTerminatorOpInterface` and is not marked as `ReturnLike`, the
-/// result will be `std::nullopt`. In all other cases, the resulting
-/// `OperandRange` represents all operands that are passed to the specified
-/// successor region. If `regionIndex` is `std::nullopt`, all operands that are
-/// passed to the parent operation will be returned.
-std::optional<MutableOperandRange>
-getMutableRegionBranchSuccessorOperands(Operation *operation,
-                                        std::optional<unsigned> regionIndex);
-
-/// Returns the read only operands that are passed to the region with the given
-/// `regionIndex`. See `getMutableRegionBranchSuccessorOperands` for more
-/// information.
-std::optional<OperandRange>
-getRegionBranchSuccessorOperands(Operation *operation,
-                                 std::optional<unsigned> regionIndex);
-
 //===----------------------------------------------------------------------===//
 // ControlFlow Traits
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td b/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td
index 1b75645261a8af..b0cea5c5565c4e 100644
--- a/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td
+++ b/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td
@@ -272,6 +272,19 @@ def RegionBranchTerminatorOpInterface :
 //===----------------------------------------------------------------------===//
 
 // Op is "return-like".
-def ReturnLike : NativeOpTrait<"ReturnLike">;
+def ReturnLike : TraitList<[
+    DeclareOpInterfaceMethods<RegionBranchTerminatorOpInterface>,
+    NativeOpTrait<
+        /*name=*/"ReturnLike",
+        /*traits=*/[],
+        /*extraOpDeclaration=*/"",
+        /*extraOpDefinition=*/[{
+          ::mlir::MutableOperandRange $cppClass::getMutableSuccessorOperands(
+            ::std::optional<unsigned> index) {
+            return ::mlir::MutableOperandRange(*this);
+          }
+        }]
+    >
+]>;
 
 #endif // MLIR_INTERFACES_CONTROLFLOWINTERFACES

diff  --git a/mlir/lib/Analysis/AliasAnalysis/LocalAliasAnalysis.cpp b/mlir/lib/Analysis/AliasAnalysis/LocalAliasAnalysis.cpp
index f205fabbac7c7a..7d893f7b918ab4 100644
--- a/mlir/lib/Analysis/AliasAnalysis/LocalAliasAnalysis.cpp
+++ b/mlir/lib/Analysis/AliasAnalysis/LocalAliasAnalysis.cpp
@@ -91,15 +91,14 @@ static void collectUnderlyingAddressValues(RegionBranchOpInterface branch,
   for (int i = 0, e = op->getNumRegions(); i != e; ++i) {
     if (std::optional<unsigned> operandIndex = getOperandIndexIfPred(i)) {
       for (Block &block : op->getRegion(i)) {
-        Operation *term = block.getTerminator();
         // Try to determine possible region-branch successor operands for the
         // current region.
-        auto successorOperands =
-            getRegionBranchSuccessorOperands(term, regionIndex);
-        if (successorOperands) {
-          collectUnderlyingAddressValues((*successorOperands)[*operandIndex],
-                                         maxDepth, visited, output);
-        } else if (term->getNumSuccessors()) {
+        if (auto term = dyn_cast<RegionBranchTerminatorOpInterface>(
+                block.getTerminator())) {
+          collectUnderlyingAddressValues(
+              term.getSuccessorOperands(regionIndex)[*operandIndex], maxDepth,
+              visited, output);
+        } else if (block.getNumSuccessors()) {
           // Otherwise, if this terminator may exit the region we can't make
           // any assumptions about which values get passed.
           output.push_back(inputValue);

diff  --git a/mlir/lib/Analysis/DataFlow/DenseAnalysis.cpp b/mlir/lib/Analysis/DataFlow/DenseAnalysis.cpp
index 6b51d55088ae99..c79a360e4c11bb 100644
--- a/mlir/lib/Analysis/DataFlow/DenseAnalysis.cpp
+++ b/mlir/lib/Analysis/DataFlow/DenseAnalysis.cpp
@@ -337,9 +337,8 @@ void AbstractDenseBackwardDataFlowAnalysis::visitBlock(Block *block) {
     // There may be a weird case where a terminator may be transferring control
     // either to the parent or to another block, so exit blocks and successors
     // are not mutually exclusive.
-    Operation *terminator = b->getTerminator();
-    return terminator && (terminator->hasTrait<OpTrait::ReturnLike>() ||
-                          isa<RegionBranchTerminatorOpInterface>(terminator));
+    return isa_and_nonnull<RegionBranchTerminatorOpInterface>(
+        b->getTerminator());
   };
   if (isExitBlock(block)) {
     // If this block is exiting from a callable, the successors of exiting from

diff  --git a/mlir/lib/Analysis/DataFlow/LivenessAnalysis.cpp b/mlir/lib/Analysis/DataFlow/LivenessAnalysis.cpp
index 968b06572633e6..0bcfb332207742 100644
--- a/mlir/lib/Analysis/DataFlow/LivenessAnalysis.cpp
+++ b/mlir/lib/Analysis/DataFlow/LivenessAnalysis.cpp
@@ -93,11 +93,9 @@ void LivenessAnalysis::visitBranchOperand(OpOperand &operand) {
   // `BranchOpInterface`, `RegionBranchTerminatorOpInterface` or return-like op.
   Operation *op = operand.getOwner();
   assert((isa<RegionBranchOpInterface>(op) || isa<BranchOpInterface>(op) ||
-          isa<RegionBranchTerminatorOpInterface>(op) ||
-          op->hasTrait<OpTrait::ReturnLike>()) &&
+          isa<RegionBranchTerminatorOpInterface>(op)) &&
          "expected the op to be `RegionBranchOpInterface`, "
-         "`BranchOpInterface`, `RegionBranchTerminatorOpInterface`, or "
-         "return-like");
+         "`BranchOpInterface` or `RegionBranchTerminatorOpInterface`");
 
   // The lattices of the non-forwarded branch operands don't get updated like
   // the forwarded branch operands or the non-branch operands. Thus they need
@@ -161,11 +159,10 @@ void LivenessAnalysis::visitBranchOperand(OpOperand &operand) {
   visitOperation(op, operandLiveness, resultsLiveness);
 
   // We also visit the parent op with the parent's results and this operand if
-  // `op` is a `RegionBranchTerminatorOpInterface` or return-like because its
-  // non-forwarded operand depends on not only its memory effects/results but
-  // also on those of its parent's.
-  if (!isa<RegionBranchTerminatorOpInterface>(op) &&
-      !op->hasTrait<OpTrait::ReturnLike>())
+  // `op` is a `RegionBranchTerminatorOpInterface` because its non-forwarded
+  // operand depends on not only its memory effects/results but also on those of
+  // its parent's.
+  if (!isa<RegionBranchTerminatorOpInterface>(op))
     return;
   Operation *parentOp = op->getParentOp();
   SmallVector<const Liveness *, 4> parentResultsLiveness;

diff  --git a/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp b/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp
index 3007b3826e439f..abe754a60cfbda 100644
--- a/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp
+++ b/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp
@@ -226,9 +226,9 @@ void AbstractSparseForwardDataFlowAnalysis::visitRegionSuccessors(
     if (op == branch) {
       operands = branch.getSuccessorEntryOperands(successorIndex);
       // Otherwise, try to deduce the operands from a region return-like op.
-    } else {
-      if (isRegionReturnLike(op))
-        operands = getRegionBranchSuccessorOperands(op, successorIndex);
+    } else if (auto regionTerminator =
+                   dyn_cast<RegionBranchTerminatorOpInterface>(op)) {
+      operands = regionTerminator.getSuccessorOperands(successorIndex);
     }
 
     if (!operands) {
@@ -439,10 +439,9 @@ void AbstractSparseBackwardDataFlowAnalysis::visitOperation(Operation *op) {
   // successor's input. There are two types of successor operands: the operands
   // of this op itself and the operands of the terminators of the regions of
   // this op.
-  if (isa<RegionBranchTerminatorOpInterface>(op) ||
-      op->hasTrait<OpTrait::ReturnLike>()) {
+  if (auto terminator = dyn_cast<RegionBranchTerminatorOpInterface>(op)) {
     if (auto branch = dyn_cast<RegionBranchOpInterface>(op->getParentOp())) {
-      visitRegionSuccessorsFromTerminator(op, branch);
+      visitRegionSuccessorsFromTerminator(terminator, branch);
       return;
     }
   }
@@ -506,12 +505,11 @@ void AbstractSparseBackwardDataFlowAnalysis::visitRegionSuccessors(
 }
 
 void AbstractSparseBackwardDataFlowAnalysis::
-    visitRegionSuccessorsFromTerminator(Operation *terminator,
-                                        RegionBranchOpInterface branch) {
-  assert(isa<RegionBranchTerminatorOpInterface>(terminator) ||
-         terminator->hasTrait<OpTrait::ReturnLike>() &&
-             "expected a `RegionBranchTerminatorOpInterface` op or a "
-             "return-like op");
+    visitRegionSuccessorsFromTerminator(
+        RegionBranchTerminatorOpInterface terminator,
+        RegionBranchOpInterface branch) {
+  assert(isa<RegionBranchTerminatorOpInterface>(terminator) &&
+         "expected a `RegionBranchTerminatorOpInterface` op");
   assert(terminator->getParentOp() == branch.getOperation() &&
          "expected `branch` to be the parent op of `terminator`");
 
@@ -527,10 +525,8 @@ void AbstractSparseBackwardDataFlowAnalysis::
   for (const RegionSuccessor &successor : successors) {
     ValueRange inputs = successor.getSuccessorInputs();
     Region *region = successor.getSuccessor();
-    OperandRange operands =
-        region ? *getRegionBranchSuccessorOperands(terminator,
-                                                   region->getRegionNumber())
-               : *getRegionBranchSuccessorOperands(terminator, {});
+    OperandRange operands = terminator.getSuccessorOperands(
+        region ? region->getRegionNumber() : std::optional<unsigned>{});
     MutableArrayRef<OpOperand> opOperands = operandsToOpOperands(operands);
     for (auto [opOperand, input] : llvm::zip(opOperands, inputs)) {
       meet(getLatticeElement(opOperand.get()),

diff  --git a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
index 5eb345df5fe236..701ac5cdc07d9f 100644
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
@@ -690,7 +690,7 @@ bool AnalysisState::isTensorYielded(Value tensor) const {
       return true;
 
     // Check if the op is returning/yielding.
-    if (isRegionReturnLike(op))
+    if (isa<RegionBranchTerminatorOpInterface>(op))
       return true;
 
     // Add all aliasing OpResults to the worklist.

diff  --git a/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocation.cpp b/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocation.cpp
index b813b2425bdd50..40d40cf46f0b6f 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocation.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocation.cpp
@@ -70,14 +70,15 @@ using namespace mlir;
 using namespace mlir::bufferization;
 
 /// Walks over all immediate return-like terminators in the given region.
-static LogicalResult
-walkReturnOperations(Region *region,
-                     llvm::function_ref<LogicalResult(Operation *)> func) {
+static LogicalResult walkReturnOperations(
+    Region *region,
+    llvm::function_ref<LogicalResult(RegionBranchTerminatorOpInterface)> func) {
   for (Block &block : *region) {
     Operation *terminator = block.getTerminator();
     // Skip non region-return-like terminators.
-    if (isRegionReturnLike(terminator)) {
-      if (failed(func(terminator)))
+    if (auto regionTerminator =
+            dyn_cast<RegionBranchTerminatorOpInterface>(terminator)) {
+      if (failed(func(regionTerminator)))
         return failure();
     }
   }
@@ -447,23 +448,25 @@ class BufferDeallocation : public BufferPlacementTransformationBase {
       // Iterate over all immediate terminator operations to introduce
       // new buffer allocations. Thereby, the appropriate terminator operand
       // will be adjusted to point to the newly allocated buffer instead.
-      if (failed(walkReturnOperations(&region, [&](Operation *terminator) {
-            // Get the actual mutable operands for this terminator op.
-            auto terminatorOperands = *getMutableRegionBranchSuccessorOperands(
-                terminator, region.getRegionNumber());
-            // Extract the source value from the current terminator.
-            // This conversion needs to exist on a separate line due to a bug in
-            // GCC conversion analysis.
-            OperandRange immutableTerminatorOperands = terminatorOperands;
-            Value sourceValue = immutableTerminatorOperands[operandIndex];
-            // Create a new clone at the current location of the terminator.
-            auto clone = introduceCloneBuffers(sourceValue, terminator);
-            if (failed(clone))
-              return failure();
-            // Wire clone and terminator operand.
-            terminatorOperands.slice(operandIndex, 1).assign(*clone);
-            return success();
-          })))
+      if (failed(walkReturnOperations(
+              &region, [&](RegionBranchTerminatorOpInterface terminator) {
+                // Get the actual mutable operands for this terminator op.
+                auto terminatorOperands =
+                    terminator.getMutableSuccessorOperands(
+                        region.getRegionNumber());
+                // Extract the source value from the current terminator.
+                // This conversion needs to exist on a separate line due to a
+                // bug in GCC conversion analysis.
+                OperandRange immutableTerminatorOperands = terminatorOperands;
+                Value sourceValue = immutableTerminatorOperands[operandIndex];
+                // Create a new clone at the current location of the terminator.
+                auto clone = introduceCloneBuffers(sourceValue, terminator);
+                if (failed(clone))
+                  return failure();
+                // Wire clone and terminator operand.
+                terminatorOperands.slice(operandIndex, 1).assign(*clone);
+                return success();
+              })))
         return failure();
     }
     return success();

diff  --git a/mlir/lib/Dialect/Bufferization/Transforms/BufferOptimizations.cpp b/mlir/lib/Dialect/Bufferization/Transforms/BufferOptimizations.cpp
index 2a2baf39235250..b4f6a5f61fba20 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/BufferOptimizations.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/BufferOptimizations.cpp
@@ -75,7 +75,8 @@ leavesAllocationScope(Region *parentRegion,
       // If there is at least one alias that leaves the parent region, we know
       // that this alias escapes the whole region and hence the associated
       // allocation leaves allocation scope.
-      if (isRegionReturnLike(use) && use->getParentRegion() == parentRegion)
+      if (isa<RegionBranchTerminatorOpInterface>(use) &&
+          use->getParentRegion() == parentRegion)
         return true;
     }
   }

diff  --git a/mlir/lib/Dialect/Bufferization/Transforms/BufferViewFlowAnalysis.cpp b/mlir/lib/Dialect/Bufferization/Transforms/BufferViewFlowAnalysis.cpp
index 003fc62f657857..39b00dd2956bc0 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/BufferViewFlowAnalysis.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/BufferViewFlowAnalysis.cpp
@@ -128,14 +128,11 @@ void BufferViewFlowAnalysis::build(Operation *op) {
             regionIndex = regionSuccessor->getRegionNumber();
           // Iterate over all immediate terminator operations and wire the
           // successor inputs with the successor operands of each terminator.
-          for (Block &block : region) {
-            auto successorOperands = getRegionBranchSuccessorOperands(
-                block.getTerminator(), regionIndex);
-            if (successorOperands) {
-              registerDependencies(*successorOperands,
+          for (Block &block : region)
+            if (auto terminator = dyn_cast<RegionBranchTerminatorOpInterface>(
+                    block.getTerminator()))
+              registerDependencies(terminator.getSuccessorOperands(regionIndex),
                                    successorRegion.getSuccessorInputs());
-            }
-          }
         }
       }
 

diff  --git a/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp b/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp
index eddb0423241ae6..0d898a4aeba461 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp
@@ -183,7 +183,8 @@ void OneShotAnalysisState::createAliasInfoEntry(Value v) {
 // the IR.
 void OneShotAnalysisState::gatherYieldedTensors(Operation *op) {
   op->walk([&](Operation *returnOp) {
-    if (!isRegionReturnLike(returnOp) || !getOptions().isOpAllowed(returnOp))
+    if (!isa<RegionBranchTerminatorOpInterface>(returnOp) ||
+        !getOptions().isOpAllowed(returnOp))
       return WalkResult::advance();
 
     for (OpOperand &returnValOperand : returnOp->getOpOperands()) {
@@ -1059,7 +1060,7 @@ static LogicalResult assertNoAllocsReturned(Operation *op,
   LogicalResult status = success();
   DominanceInfo domInfo(op);
   op->walk([&](Operation *returnOp) {
-    if (!isRegionReturnLike(returnOp) ||
+    if (!isa<RegionBranchTerminatorOpInterface>(returnOp) ||
         !state.getOptions().isOpAllowed(returnOp))
       return WalkResult::advance();
 

diff  --git a/mlir/lib/Interfaces/ControlFlowInterfaces.cpp b/mlir/lib/Interfaces/ControlFlowInterfaces.cpp
index deec6058c1a0cd..e4eefaa450b89a 100644
--- a/mlir/lib/Interfaces/ControlFlowInterfaces.cpp
+++ b/mlir/lib/Interfaces/ControlFlowInterfaces.cpp
@@ -183,12 +183,13 @@ LogicalResult detail::verifyTypesAlongControlFlowEdges(Operation *op) {
 
     std::optional<OperandRange> regionReturnOperands;
     for (Block &block : region) {
-      Operation *terminator = block.getTerminator();
-      auto terminatorOperands =
-          getRegionBranchSuccessorOperands(terminator, regionNo);
-      if (!terminatorOperands)
+      auto terminator =
+          dyn_cast<RegionBranchTerminatorOpInterface>(block.getTerminator());
+      if (!terminator)
         continue;
 
+      OperandRange terminatorOperands =
+          terminator.getSuccessorOperands(regionNo);
       if (!regionReturnOperands) {
         regionReturnOperands = terminatorOperands;
         continue;
@@ -197,7 +198,7 @@ LogicalResult detail::verifyTypesAlongControlFlowEdges(Operation *op) {
       // Found more than one ReturnLike terminator. Make sure the operand types
       // match with the first one.
       if (!areTypesCompatible(regionReturnOperands->getTypes(),
-                              terminatorOperands->getTypes()))
+                              terminatorOperands.getTypes()))
         return op->emitOpError("Region #")
                << regionNo
                << " operands mismatch between return-like terminators";
@@ -316,7 +317,7 @@ void RegionBranchOpInterface::getSuccessorRegions(
     // exiting terminator in the region.
     for (Block &block : getOperation()->getRegion(*index)) {
       Operation *terminator = block.getTerminator();
-      if (getRegionBranchSuccessorOperands(terminator, *index)) {
+      if (isa<RegionBranchTerminatorOpInterface>(terminator)) {
         numInputs = terminator->getNumOperands();
         break;
       }
@@ -350,51 +351,3 @@ Region *mlir::getEnclosingRepetitiveRegion(Value value) {
   }
   return nullptr;
 }
-
-//===----------------------------------------------------------------------===//
-// RegionBranchTerminatorOpInterface
-//===----------------------------------------------------------------------===//
-
-/// Returns true if the given operation is either annotated with the
-/// `ReturnLike` trait or implements the `RegionBranchTerminatorOpInterface`.
-bool mlir::isRegionReturnLike(Operation *operation) {
-  return dyn_cast<RegionBranchTerminatorOpInterface>(operation) ||
-         operation->hasTrait<OpTrait::ReturnLike>();
-}
-
-/// Returns the mutable operands that are passed to the region with the given
-/// `regionIndex`. If the operation does not implement the
-/// `RegionBranchTerminatorOpInterface` and is not marked as `ReturnLike`, the
-/// result will be `std::nullopt`. In all other cases, the resulting
-/// `OperandRange` represents all operands that are passed to the specified
-/// successor region. If `regionIndex` is `std::nullopt`, all operands that are
-/// passed to the parent operation will be returned.
-std::optional<MutableOperandRange>
-mlir::getMutableRegionBranchSuccessorOperands(
-    Operation *operation, std::optional<unsigned> regionIndex) {
-  // Try to query a RegionBranchTerminatorOpInterface to determine
-  // all successor operands that will be passed to the successor
-  // input arguments.
-  if (auto regionTerminatorInterface =
-          dyn_cast<RegionBranchTerminatorOpInterface>(operation))
-    return regionTerminatorInterface.getMutableSuccessorOperands(regionIndex);
-
-  // TODO: The ReturnLike trait should imply a default implementation of the
-  // RegionBranchTerminatorOpInterface. This would make this code significantly
-  // easier. Furthermore, this may even make this function obsolete.
-  if (operation->hasTrait<OpTrait::ReturnLike>())
-    return MutableOperandRange(operation);
-  return std::nullopt;
-}
-
-/// Returns the read only operands that are passed to the region with the given
-/// `regionIndex`. See `getMutableRegionBranchSuccessorOperands` for more
-/// information.
-std::optional<OperandRange>
-mlir::getRegionBranchSuccessorOperands(Operation *operation,
-                                       std::optional<unsigned> regionIndex) {
-  auto range = getMutableRegionBranchSuccessorOperands(operation, regionIndex);
-  if (range)
-    return range->operator OperandRange();
-  return std::nullopt;
-}


        

