[Mlir-commits] [mlir] 0752d98 - [mlir] Simplify BranchOpInterface by using MutableOperandRange

River Riddle llvmlistbot at llvm.org
Wed Apr 29 16:49:48 PDT 2020


Author: River Riddle
Date: 2020-04-29T16:48:15-07:00
New Revision: 0752d98ccf8771b41718170d46d11f4020b62818

URL: https://github.com/llvm/llvm-project/commit/0752d98ccf8771b41718170d46d11f4020b62818
DIFF: https://github.com/llvm/llvm-project/commit/0752d98ccf8771b41718170d46d11f4020b62818.diff

LOG: [mlir] Simplify BranchOpInterface by using MutableOperandRange

MutableOperandRange allows performing many different operations on successor operands, including erasing, adding, and setting them. Using it removes the need for the explicit canEraseSuccessorOperand and eraseSuccessorOperand methods.

Differential Revision: https://reviews.llvm.org/D79077

Added: 
    

Modified: 
    flang/include/flang/Optimizer/Dialect/FIROps.td
    flang/lib/Optimizer/Dialect/FIROps.cpp
    mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
    mlir/include/mlir/IR/OperationSupport.h
    mlir/include/mlir/Interfaces/ControlFlowInterfaces.h
    mlir/include/mlir/Interfaces/ControlFlowInterfaces.td
    mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
    mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
    mlir/lib/Dialect/StandardOps/IR/Ops.cpp
    mlir/lib/IR/OperationSupport.cpp
    mlir/lib/Interfaces/ControlFlowInterfaces.cpp
    mlir/lib/Transforms/Utils/RegionUtils.cpp
    mlir/test/lib/Dialect/Test/TestDialect.cpp
    mlir/tools/mlir-tblgen/OpInterfacesGen.cpp

Removed: 
    


################################################################################
diff  --git a/flang/include/flang/Optimizer/Dialect/FIROps.td b/flang/include/flang/Optimizer/Dialect/FIROps.td
index 46c39a1d498b..383256c3916f 100644
--- a/flang/include/flang/Optimizer/Dialect/FIROps.td
+++ b/flang/include/flang/Optimizer/Dialect/FIROps.td
@@ -585,6 +585,7 @@ class fir_SwitchTerminatorOp<string mnemonic, list<OpTrait> traits = []> :
 
     llvm::Optional<llvm::ArrayRef<mlir::Value>> getSuccessorOperands(
         llvm::ArrayRef<mlir::Value> operands, unsigned cond);
+    using BranchOpInterfaceTrait::getSuccessorOperands;
 
     // Helper function to deal with Optional operand forms
     void printSuccessorAtIndex(mlir::OpAsmPrinter &p, unsigned i) {

diff  --git a/flang/lib/Optimizer/Dialect/FIROps.cpp b/flang/lib/Optimizer/Dialect/FIROps.cpp
index 1dd15fc959be..e2d94885e8fc 100644
--- a/flang/lib/Optimizer/Dialect/FIROps.cpp
+++ b/flang/lib/Optimizer/Dialect/FIROps.cpp
@@ -997,14 +997,26 @@ static constexpr llvm::StringRef getTargetOffsetAttr() {
   return "target_operand_offsets";
 }
 
-template <typename A>
+template <typename A, typename... AdditionalArgs>
 static A getSubOperands(unsigned pos, A allArgs,
-                        mlir::DenseIntElementsAttr ranges) {
+                        mlir::DenseIntElementsAttr ranges,
+                        AdditionalArgs &&... additionalArgs) {
   unsigned start = 0;
   for (unsigned i = 0; i < pos; ++i)
     start += (*(ranges.begin() + i)).getZExtValue();
-  unsigned end = start + (*(ranges.begin() + pos)).getZExtValue();
-  return {std::next(allArgs.begin(), start), std::next(allArgs.begin(), end)};
+  return allArgs.slice(start, (*(ranges.begin() + pos)).getZExtValue(),
+                       std::forward<AdditionalArgs>(additionalArgs)...);
+}
+
+static mlir::MutableOperandRange
+getMutableSuccessorOperands(unsigned pos, mlir::MutableOperandRange operands,
+                            StringRef offsetAttr) {
+  Operation *owner = operands.getOwner();
+  NamedAttribute targetOffsetAttr =
+      *owner->getMutableAttrDict().getNamed(offsetAttr);
+  return getSubOperands(
+      pos, operands, targetOffsetAttr.second.cast<DenseIntElementsAttr>(),
+      mlir::MutableOperandRange::OperandSegment(pos, targetOffsetAttr));
 }
 
 static unsigned denseElementsSize(mlir::DenseIntElementsAttr attr) {
@@ -1020,10 +1032,10 @@ fir::SelectOp::getCompareOperands(llvm::ArrayRef<mlir::Value>, unsigned) {
   return {};
 }
 
-llvm::Optional<mlir::OperandRange>
-fir::SelectOp::getSuccessorOperands(unsigned oper) {
-  auto a = getAttrOfType<mlir::DenseIntElementsAttr>(getTargetOffsetAttr());
-  return {getSubOperands(oper, targetArgs(), a)};
+llvm::Optional<mlir::MutableOperandRange>
+fir::SelectOp::getMutableSuccessorOperands(unsigned oper) {
+  return ::getMutableSuccessorOperands(oper, targetArgsMutable(),
+                                       getTargetOffsetAttr());
 }
 
 llvm::Optional<llvm::ArrayRef<mlir::Value>>
@@ -1035,8 +1047,6 @@ fir::SelectOp::getSuccessorOperands(llvm::ArrayRef<mlir::Value> operands,
   return {getSubOperands(oper, getSubOperands(2, operands, segments), a)};
 }
 
-bool fir::SelectOp::canEraseSuccessorOperand() { return true; }
-
 unsigned fir::SelectOp::targetOffsetSize() {
   return denseElementsSize(
       getAttrOfType<mlir::DenseIntElementsAttr>(getTargetOffsetAttr()));
@@ -1061,10 +1071,10 @@ fir::SelectCaseOp::getCompareOperands(llvm::ArrayRef<mlir::Value> operands,
   return {getSubOperands(cond, getSubOperands(1, operands, segments), a)};
 }
 
-llvm::Optional<mlir::OperandRange>
-fir::SelectCaseOp::getSuccessorOperands(unsigned oper) {
-  auto a = getAttrOfType<mlir::DenseIntElementsAttr>(getTargetOffsetAttr());
-  return {getSubOperands(oper, targetArgs(), a)};
+llvm::Optional<mlir::MutableOperandRange>
+fir::SelectCaseOp::getMutableSuccessorOperands(unsigned oper) {
+  return ::getMutableSuccessorOperands(oper, targetArgsMutable(),
+                                       getTargetOffsetAttr());
 }
 
 llvm::Optional<llvm::ArrayRef<mlir::Value>>
@@ -1076,8 +1086,6 @@ fir::SelectCaseOp::getSuccessorOperands(llvm::ArrayRef<mlir::Value> operands,
   return {getSubOperands(oper, getSubOperands(2, operands, segments), a)};
 }
 
-bool fir::SelectCaseOp::canEraseSuccessorOperand() { return true; }
-
 // parser for fir.select_case Op
 static mlir::ParseResult parseSelectCase(mlir::OpAsmParser &parser,
                                          mlir::OperationState &result) {
@@ -1254,10 +1262,10 @@ fir::SelectRankOp::getCompareOperands(llvm::ArrayRef<mlir::Value>, unsigned) {
   return {};
 }
 
-llvm::Optional<mlir::OperandRange>
-fir::SelectRankOp::getSuccessorOperands(unsigned oper) {
-  auto a = getAttrOfType<mlir::DenseIntElementsAttr>(getTargetOffsetAttr());
-  return {getSubOperands(oper, targetArgs(), a)};
+llvm::Optional<mlir::MutableOperandRange>
+fir::SelectRankOp::getMutableSuccessorOperands(unsigned oper) {
+  return ::getMutableSuccessorOperands(oper, targetArgsMutable(),
+                                       getTargetOffsetAttr());
 }
 
 llvm::Optional<llvm::ArrayRef<mlir::Value>>
@@ -1269,8 +1277,6 @@ fir::SelectRankOp::getSuccessorOperands(llvm::ArrayRef<mlir::Value> operands,
   return {getSubOperands(oper, getSubOperands(2, operands, segments), a)};
 }
 
-bool fir::SelectRankOp::canEraseSuccessorOperand() { return true; }
-
 unsigned fir::SelectRankOp::targetOffsetSize() {
   return denseElementsSize(
       getAttrOfType<mlir::DenseIntElementsAttr>(getTargetOffsetAttr()));
@@ -1290,10 +1296,10 @@ fir::SelectTypeOp::getCompareOperands(llvm::ArrayRef<mlir::Value>, unsigned) {
   return {};
 }
 
-llvm::Optional<mlir::OperandRange>
-fir::SelectTypeOp::getSuccessorOperands(unsigned oper) {
-  auto a = getAttrOfType<mlir::DenseIntElementsAttr>(getTargetOffsetAttr());
-  return {getSubOperands(oper, targetArgs(), a)};
+llvm::Optional<mlir::MutableOperandRange>
+fir::SelectTypeOp::getMutableSuccessorOperands(unsigned oper) {
+  return ::getMutableSuccessorOperands(oper, targetArgsMutable(),
+                                       getTargetOffsetAttr());
 }
 
 llvm::Optional<llvm::ArrayRef<mlir::Value>>
@@ -1305,8 +1311,6 @@ fir::SelectTypeOp::getSuccessorOperands(llvm::ArrayRef<mlir::Value> operands,
   return {getSubOperands(oper, getSubOperands(2, operands, segments), a)};
 }
 
-bool fir::SelectTypeOp::canEraseSuccessorOperand() { return true; }
-
 static ParseResult parseSelectType(OpAsmParser &parser,
                                    OperationState &result) {
   mlir::OpAsmParser::OperandType selector;

diff  --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
index 48ed99051642..87f8e629a5c6 100644
--- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
+++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
@@ -1074,7 +1074,7 @@ def CondBranchOp : Std_Op<"cond_br",
 
     /// Erase the operand at 'index' from the true operand list.
     void eraseTrueOperand(unsigned index)  {
-      eraseSuccessorOperand(trueIndex, index);
+      trueDestOperandsMutable().erase(index);
     }
 
     // Accessors for operands to the 'false' destination.
@@ -1093,7 +1093,7 @@ def CondBranchOp : Std_Op<"cond_br",
 
     /// Erase the operand at 'index' from the false operand list.
     void eraseFalseOperand(unsigned index) {
-      eraseSuccessorOperand(falseIndex, index);
+      falseDestOperandsMutable().erase(index);
     }
 
   private:

diff  --git a/mlir/include/mlir/IR/OperationSupport.h b/mlir/include/mlir/IR/OperationSupport.h
index 2214b5db2f20..edfe89ad97f2 100644
--- a/mlir/include/mlir/IR/OperationSupport.h
+++ b/mlir/include/mlir/IR/OperationSupport.h
@@ -678,6 +678,10 @@ class MutableOperandRange {
                       ArrayRef<OperandSegment> operandSegments = llvm::None);
   MutableOperandRange(Operation *owner);
 
+  /// Slice this range into a sub range, with the additional operand segment.
+  MutableOperandRange slice(unsigned subStart, unsigned subLen,
+                            Optional<OperandSegment> segment = llvm::None);
+
   /// Append the given values to the range.
   void append(ValueRange values);
 
@@ -699,6 +703,9 @@ class MutableOperandRange {
   /// Allow implicit conversion to an OperandRange.
   operator OperandRange() const;
 
+  /// Returns the owning operation.
+  Operation *getOwner() const { return owner; }
+
 private:
   /// Update the length of this range to the one provided.
   void updateLength(unsigned newLength);

diff  --git a/mlir/include/mlir/Interfaces/ControlFlowInterfaces.h b/mlir/include/mlir/Interfaces/ControlFlowInterfaces.h
index e22454538343..e18c46f745a2 100644
--- a/mlir/include/mlir/Interfaces/ControlFlowInterfaces.h
+++ b/mlir/include/mlir/Interfaces/ControlFlowInterfaces.h
@@ -24,11 +24,6 @@ class BranchOpInterface;
 //===----------------------------------------------------------------------===//
 
 namespace detail {
-/// Erase an operand from a branch operation that is used as a successor
-/// operand. `operandIndex` is the operand within `operands` to be erased.
-void eraseBranchSuccessorOperand(OperandRange operands, unsigned operandIndex,
-                                 Operation *op);
-
 /// Return the `BlockArgument` corresponding to operand `operandIndex` in some
 /// successor if `operandIndex` is within the range of `operands`, or None if
 /// `operandIndex` isn't a successor operand index.

diff  --git a/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td b/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td
index 5c02482394b7..591ca11830e9 100644
--- a/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td
+++ b/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td
@@ -27,29 +27,25 @@ def BranchOpInterface : OpInterface<"BranchOpInterface"> {
   }];
   let methods = [
     InterfaceMethod<[{
-        Returns a set of values that correspond to the arguments to the
+        Returns a mutable range of operands that correspond to the arguments of
         successor at the given index. Returns None if the operands to the
         successor are non-materialized values, i.e. they are internal to the
         operation.
       }],
-      "Optional<OperandRange>", "getSuccessorOperands", (ins "unsigned":$index)
+      "Optional<MutableOperandRange>", "getMutableSuccessorOperands",
+      (ins "unsigned":$index)
     >,
     InterfaceMethod<[{
-        Return true if this operation can erase an operand to a successor block.
-      }],
-      "bool", "canEraseSuccessorOperand"
-    >,
-    InterfaceMethod<[{
-        Erase the operand at `operandIndex` from the `index`-th successor. This
-        should only be called if `canEraseSuccessorOperand` returns true.
+        Returns a range of operands that correspond to the arguments of
+        successor at the given index. Returns None if the operands to the
+        successor are non-materialized values, i.e. they are internal to the
+        operation.
       }],
-      "void", "eraseSuccessorOperand",
-      (ins "unsigned":$index, "unsigned":$operandIndex), [{}],
-      /*defaultImplementation=*/[{
+      "Optional<OperandRange>", "getSuccessorOperands",
+      (ins "unsigned":$index), [{}], [{
         ConcreteOp *op = static_cast<ConcreteOp *>(this);
-        Optional<OperandRange> operands = op->getSuccessorOperands(index);
-        assert(operands && "unable to query operands for successor");
-        detail::eraseBranchSuccessorOperand(*operands, operandIndex, *op);
+        auto operands = op->getMutableSuccessorOperands(index);
+        return operands ? Optional<OperandRange>(*operands) : llvm::None;
       }]
     >,
     InterfaceMethod<[{

diff  --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
index 0a462d0239e3..5c112710ec55 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
@@ -160,24 +160,22 @@ static ParseResult parseAllocaOp(OpAsmParser &parser, OperationState &result) {
 // LLVM::BrOp
 //===----------------------------------------------------------------------===//
 
-Optional<OperandRange> BrOp::getSuccessorOperands(unsigned index) {
+Optional<MutableOperandRange>
+BrOp::getMutableSuccessorOperands(unsigned index) {
   assert(index == 0 && "invalid successor index");
-  return getOperands();
+  return destOperandsMutable();
 }
 
-bool BrOp::canEraseSuccessorOperand() { return true; }
-
 //===----------------------------------------------------------------------===//
 // LLVM::CondBrOp
 //===----------------------------------------------------------------------===//
 
-Optional<OperandRange> CondBrOp::getSuccessorOperands(unsigned index) {
+Optional<MutableOperandRange>
+CondBrOp::getMutableSuccessorOperands(unsigned index) {
   assert(index < getNumSuccessors() && "invalid successor index");
-  return index == 0 ? trueDestOperands() : falseDestOperands();
+  return index == 0 ? trueDestOperandsMutable() : falseDestOperandsMutable();
 }
 
-bool CondBrOp::canEraseSuccessorOperand() { return true; }
-
 //===----------------------------------------------------------------------===//
 // Printing/parsing for LLVM::LoadOp.
 //===----------------------------------------------------------------------===//
@@ -257,13 +255,12 @@ static ParseResult parseStoreOp(OpAsmParser &parser, OperationState &result) {
 /// LLVM::InvokeOp
 ///===---------------------------------------------------------------------===//
 
-Optional<OperandRange> InvokeOp::getSuccessorOperands(unsigned index) {
+Optional<MutableOperandRange>
+InvokeOp::getMutableSuccessorOperands(unsigned index) {
   assert(index < getNumSuccessors() && "invalid successor index");
-  return index == 0 ? normalDestOperands() : unwindDestOperands();
+  return index == 0 ? normalDestOperandsMutable() : unwindDestOperandsMutable();
 }
 
-bool InvokeOp::canEraseSuccessorOperand() { return true; }
-
 static LogicalResult verify(InvokeOp op) {
   if (op.getNumResults() > 1)
     return op.emitOpError("must have 0 or 1 result");

diff  --git a/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
index ed98d3745d6f..5d4e309a2e96 100644
--- a/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
@@ -987,26 +987,23 @@ static LogicalResult verify(spirv::BitcastOp bitcastOp) {
 // spv.BranchOp
 //===----------------------------------------------------------------------===//
 
-Optional<OperandRange> spirv::BranchOp::getSuccessorOperands(unsigned index) {
+Optional<MutableOperandRange>
+spirv::BranchOp::getMutableSuccessorOperands(unsigned index) {
   assert(index == 0 && "invalid successor index");
-  return getOperands();
+  return targetOperandsMutable();
 }
 
-bool spirv::BranchOp::canEraseSuccessorOperand() { return true; }
-
 //===----------------------------------------------------------------------===//
 // spv.BranchConditionalOp
 //===----------------------------------------------------------------------===//
 
-Optional<OperandRange>
-spirv::BranchConditionalOp::getSuccessorOperands(unsigned index) {
+Optional<MutableOperandRange>
+spirv::BranchConditionalOp::getMutableSuccessorOperands(unsigned index) {
   assert(index < 2 && "invalid successor index");
-  return index == kTrueIndex ? getTrueBlockArguments()
-                             : getFalseBlockArguments();
+  return index == kTrueIndex ? trueTargetOperandsMutable()
+                             : falseTargetOperandsMutable();
 }
 
-bool spirv::BranchConditionalOp::canEraseSuccessorOperand() { return true; }
-
 static ParseResult parseBranchConditionalOp(OpAsmParser &parser,
                                             OperationState &state) {
   auto &builder = parser.getBuilder();

diff  --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
index 85efc4391234..8ef24e239152 100644
--- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
+++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
@@ -677,13 +677,12 @@ void BranchOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
       context);
 }
 
-Optional<OperandRange> BranchOp::getSuccessorOperands(unsigned index) {
+Optional<MutableOperandRange>
+BranchOp::getMutableSuccessorOperands(unsigned index) {
   assert(index == 0 && "invalid successor index");
-  return getOperands();
+  return destOperandsMutable();
 }
 
-bool BranchOp::canEraseSuccessorOperand() { return true; }
-
 Block *BranchOp::getSuccessorForOperands(ArrayRef<Attribute>) { return dest(); }
 
 //===----------------------------------------------------------------------===//
@@ -1021,13 +1020,13 @@ void CondBranchOp::getCanonicalizationPatterns(
                  SimplifyCondBranchIdenticalSuccessors>(context);
 }
 
-Optional<OperandRange> CondBranchOp::getSuccessorOperands(unsigned index) {
+Optional<MutableOperandRange>
+CondBranchOp::getMutableSuccessorOperands(unsigned index) {
   assert(index < getNumSuccessors() && "invalid successor index");
-  return index == trueIndex ? getTrueOperands() : getFalseOperands();
+  return index == trueIndex ? trueDestOperandsMutable()
+                            : falseDestOperandsMutable();
 }
 
-bool CondBranchOp::canEraseSuccessorOperand() { return true; }
-
 Block *CondBranchOp::getSuccessorForOperands(ArrayRef<Attribute> operands) {
   if (BoolAttr condAttr = operands.front().dyn_cast_or_null<BoolAttr>())
     return condAttr.getValue() ? trueDest() : falseDest();

diff  --git a/mlir/lib/IR/OperationSupport.cpp b/mlir/lib/IR/OperationSupport.cpp
index 83b4f0bf176e..a08762326143 100644
--- a/mlir/lib/IR/OperationSupport.cpp
+++ b/mlir/lib/IR/OperationSupport.cpp
@@ -287,6 +287,18 @@ MutableOperandRange::MutableOperandRange(
 MutableOperandRange::MutableOperandRange(Operation *owner)
     : MutableOperandRange(owner, /*start=*/0, owner->getNumOperands()) {}
 
+/// Slice this range into a sub range, with the additional operand segment.
+MutableOperandRange
+MutableOperandRange::slice(unsigned subStart, unsigned subLen,
+                           Optional<OperandSegment> segment) {
+  assert((subStart + subLen) <= length && "invalid sub-range");
+  MutableOperandRange subSlice(owner, start + subStart, subLen,
+                               operandSegments);
+  if (segment)
+    subSlice.operandSegments.push_back(*segment);
+  return subSlice;
+}
+
 /// Append the given values to the range.
 void MutableOperandRange::append(ValueRange values) {
   if (values.empty())

diff  --git a/mlir/lib/Interfaces/ControlFlowInterfaces.cpp b/mlir/lib/Interfaces/ControlFlowInterfaces.cpp
index 746dd402a35a..c1fa833f26da 100644
--- a/mlir/lib/Interfaces/ControlFlowInterfaces.cpp
+++ b/mlir/lib/Interfaces/ControlFlowInterfaces.cpp
@@ -21,39 +21,6 @@ using namespace mlir;
 // BranchOpInterface
 //===----------------------------------------------------------------------===//
 
-/// Erase an operand from a branch operation that is used as a successor
-/// operand. 'operandIndex' is the operand within 'operands' to be erased.
-void mlir::detail::eraseBranchSuccessorOperand(OperandRange operands,
-                                               unsigned operandIndex,
-                                               Operation *op) {
-  assert(operandIndex < operands.size() &&
-         "invalid index for successor operands");
-
-  // Erase the operand from the operation.
-  size_t fullOperandIndex = operands.getBeginOperandIndex() + operandIndex;
-  op->eraseOperand(fullOperandIndex);
-
-  // If this operation has an OperandSegmentSizeAttr, keep it up to date.
-  auto operandSegmentAttr =
-      op->getAttrOfType<DenseElementsAttr>("operand_segment_sizes");
-  if (!operandSegmentAttr)
-    return;
-
-  // Find the segment containing the full operand index and decrement it.
-  // TODO: This seems like a general utility that could be added somewhere.
-  SmallVector<int32_t, 4> values(operandSegmentAttr.getValues<int32_t>());
-  unsigned currentSize = 0;
-  for (unsigned i = 0, e = values.size(); i != e; ++i) {
-    currentSize += values[i];
-    if (fullOperandIndex < currentSize) {
-      --values[i];
-      break;
-    }
-  }
-  op->setAttr("operand_segment_sizes",
-              DenseIntElementsAttr::get(operandSegmentAttr.getType(), values));
-}
-
 /// Returns the `BlockArgument` corresponding to operand `operandIndex` in some
 /// successor if 'operandIndex' is within the range of 'operands', or None if
 /// `operandIndex` isn't a successor operand index.

diff  --git a/mlir/lib/Transforms/Utils/RegionUtils.cpp b/mlir/lib/Transforms/Utils/RegionUtils.cpp
index 162091cd53de..7a00032650b2 100644
--- a/mlir/lib/Transforms/Utils/RegionUtils.cpp
+++ b/mlir/lib/Transforms/Utils/RegionUtils.cpp
@@ -209,7 +209,7 @@ static void propagateTerminatorLiveness(Operation *op, LiveMap &liveMap) {
 
   // Check to see if we can reason about the successor operands and mutate them.
   BranchOpInterface branchInterface = dyn_cast<BranchOpInterface>(op);
-  if (!branchInterface || !branchInterface.canEraseSuccessorOperand()) {
+  if (!branchInterface) {
     for (Block *successor : op->getSuccessors())
       for (BlockArgument arg : successor->getArguments())
         liveMap.setProvedLive(arg);
@@ -219,7 +219,7 @@ static void propagateTerminatorLiveness(Operation *op, LiveMap &liveMap) {
   // If we can't reason about the operands to a successor, conservatively mark
   // all arguments as live.
   for (unsigned i = 0, e = op->getNumSuccessors(); i != e; ++i) {
-    if (!branchInterface.getSuccessorOperands(i))
+    if (!branchInterface.getMutableSuccessorOperands(i))
       for (BlockArgument arg : op->getSuccessor(i)->getArguments())
         liveMap.setProvedLive(arg);
   }
@@ -278,7 +278,8 @@ static void eraseTerminatorSuccessorOperands(Operation *terminator,
     // since it will promote later operands of the terminator being erased
     // first, reducing the quadratic-ness.
     unsigned succ = succE - succI - 1;
-    Optional<OperandRange> succOperands = branchOp.getSuccessorOperands(succ);
+    Optional<MutableOperandRange> succOperands =
+        branchOp.getMutableSuccessorOperands(succ);
     if (!succOperands)
       continue;
     Block *successor = terminator->getSuccessor(succ);
@@ -288,7 +289,7 @@ static void eraseTerminatorSuccessorOperands(Operation *terminator,
       // shifting later args when earlier args are erased.
       unsigned arg = argE - argI - 1;
       if (!liveMap.wasProvenLive(successor->getArgument(arg)))
-        branchOp.eraseSuccessorOperand(succ, arg);
+        succOperands->erase(arg);
     }
   }
 }

diff  --git a/mlir/test/lib/Dialect/Test/TestDialect.cpp b/mlir/test/lib/Dialect/Test/TestDialect.cpp
index 4c67310e3705..1a40f9989eae 100644
--- a/mlir/test/lib/Dialect/Test/TestDialect.cpp
+++ b/mlir/test/lib/Dialect/Test/TestDialect.cpp
@@ -167,13 +167,12 @@ TestDialect::verifyRegionResultAttribute(Operation *op, unsigned regionIndex,
 // TestBranchOp
 //===----------------------------------------------------------------------===//
 
-Optional<OperandRange> TestBranchOp::getSuccessorOperands(unsigned index) {
+Optional<MutableOperandRange>
+TestBranchOp::getMutableSuccessorOperands(unsigned index) {
   assert(index == 0 && "invalid successor index");
-  return getOperands();
+  return targetOperandsMutable();
 }
 
-bool TestBranchOp::canEraseSuccessorOperand() { return true; }
-
 //===----------------------------------------------------------------------===//
 // Test IsolatedRegionOp - parse passthrough region arguments.
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/tools/mlir-tblgen/OpInterfacesGen.cpp b/mlir/tools/mlir-tblgen/OpInterfacesGen.cpp
index 12ba8d43c9c1..ae86f713c462 100644
--- a/mlir/tools/mlir-tblgen/OpInterfacesGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpInterfacesGen.cpp
@@ -146,7 +146,7 @@ static void emitTraitDecl(OpInterface &interface, raw_ostream &os,
                           StringRef interfaceName,
                           StringRef interfaceTraitsName) {
   os << "  template <typename ConcreteOp>\n  "
-     << llvm::formatv("struct Trait : public OpInterface<{0},"
+     << llvm::formatv("struct {0}Trait : public OpInterface<{0},"
                       " detail::{1}>::Trait<ConcreteOp> {{\n",
                       interfaceName, interfaceTraitsName);
 
@@ -171,13 +171,17 @@ static void emitTraitDecl(OpInterface &interface, raw_ostream &os,
   tblgen::FmtContext traitCtx;
   traitCtx.withOp("op");
   if (auto verify = interface.getVerify()) {
-    os << "  static LogicalResult verifyTrait(Operation* op) {\n"
+    os << "    static LogicalResult verifyTrait(Operation* op) {\n"
        << std::string(tblgen::tgfmt(*verify, &traitCtx)) << "\n  }\n";
   }
   if (auto extraTraitDecls = interface.getExtraTraitClassDeclaration())
     os << extraTraitDecls << "\n";
 
   os << "  };\n";
+
+  // Emit a utility using directive for the trait class.
+  os << "    template <typename ConcreteOp>\n    "
+     << llvm::formatv("using Trait = {0}Trait<ConcreteOp>;\n", interfaceName);
 }
 
 static void emitInterfaceDecl(OpInterface &interface, raw_ostream &os) {


        


More information about the Mlir-commits mailing list