[Mlir-commits] [mlir] 0c789db - [mlir] Add support for operation-produced successor arguments in BranchOpInterface

Markus Böck llvmlistbot at llvm.org
Thu Apr 7 23:28:36 PDT 2022


Author: Markus Böck
Date: 2022-04-08T08:28:16+02:00
New Revision: 0c789db541c236abf47265331a2f2b0945aa7b93

URL: https://github.com/llvm/llvm-project/commit/0c789db541c236abf47265331a2f2b0945aa7b93
DIFF: https://github.com/llvm/llvm-project/commit/0c789db541c236abf47265331a2f2b0945aa7b93.diff

LOG: [mlir] Add support for operation-produced successor arguments in BranchOpInterface

This patch revamps the BranchOpInterface a bit and allows a proper implementation of what was previously `getMutableSuccessorOperands` for operations that internally produce values for some of their successors' block arguments. A motivating example would be an invoke op with an error-handling path:
```
invoke %function(%0)
  label ^success ^error(%1 : i32)

^error(%e: !error, %arg0 : i32):
  ...
```
The advantage of this is that users of `BranchOpInterface` can still reason about the remaining block argument operands (such as `%1` in the example above), and can use the interface's mutating capabilities to append operands, erase an operand, etc.

This patch implements that functionality via a new class called `SuccessorOperands`, which is now returned by `getSuccessorOperands` (replacing `getMutableSuccessorOperands`). It contains an `unsigned` denoting how many operation-produced operands exist, as well as a `MutableOperandRange` holding the usual forwarded operands. The produced operands are assumed to map to the first few block arguments, with the forwarded operands mapping to the remaining ones. The role of `SuccessorOperands` is to provide utility functions to query and modify the successor arguments of a `BranchOpInterface`.

Differential Revision: https://reviews.llvm.org/D123062

Added: 
    

Modified: 
    flang/include/flang/Optimizer/Dialect/FIROps.td
    flang/lib/Optimizer/Dialect/FIROps.cpp
    mlir/include/mlir/IR/OperationSupport.h
    mlir/include/mlir/Interfaces/ControlFlowInterfaces.h
    mlir/include/mlir/Interfaces/ControlFlowInterfaces.td
    mlir/lib/Analysis/AliasAnalysis/LocalAliasAnalysis.cpp
    mlir/lib/Analysis/BufferViewFlowAnalysis.cpp
    mlir/lib/Analysis/DataFlowAnalysis.cpp
    mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocation.cpp
    mlir/lib/Dialect/ControlFlow/IR/ControlFlowOps.cpp
    mlir/lib/Dialect/Func/Transforms/FuncConversions.cpp
    mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
    mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp
    mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
    mlir/lib/Interfaces/ControlFlowInterfaces.cpp
    mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
    mlir/lib/Transforms/Utils/RegionUtils.cpp
    mlir/test/Transforms/sccp.mlir
    mlir/test/lib/Dialect/Test/TestDialect.cpp
    mlir/test/lib/Dialect/Test/TestOps.td

Removed: 
    


################################################################################
diff  --git a/flang/include/flang/Optimizer/Dialect/FIROps.td b/flang/include/flang/Optimizer/Dialect/FIROps.td
index 6eb0fdfea669d..e0c09396dc7c9 100644
--- a/flang/include/flang/Optimizer/Dialect/FIROps.td
+++ b/flang/include/flang/Optimizer/Dialect/FIROps.td
@@ -489,16 +489,12 @@ class fir_SwitchTerminatorOp<string mnemonic, list<Trait> traits = []> :
         llvm::ArrayRef<mlir::Value> operands, unsigned cond);
     llvm::Optional<mlir::ValueRange> getSuccessorOperands(
         mlir::ValueRange operands, unsigned cond);
-    using BranchOpInterfaceTrait::getSuccessorOperands;
 
     // Helper function to deal with Optional operand forms
     void printSuccessorAtIndex(mlir::OpAsmPrinter &p, unsigned i) {
       auto *succ = getSuccessor(i);
       auto ops = getSuccessorOperands(i);
-      if (ops.hasValue())
-        p.printSuccessorAndUseList(succ, ops.getValue());
-      else
-        p.printSuccessor(succ);
+      p.printSuccessorAndUseList(succ, ops.getForwardedOperands());
     }
 
     mlir::ArrayAttr getCases() {

diff  --git a/flang/lib/Optimizer/Dialect/FIROps.cpp b/flang/lib/Optimizer/Dialect/FIROps.cpp
index 6dfc5a90d8fcb..1b48a56adbcaf 100644
--- a/flang/lib/Optimizer/Dialect/FIROps.cpp
+++ b/flang/lib/Optimizer/Dialect/FIROps.cpp
@@ -2401,10 +2401,9 @@ fir::SelectOp::getCompareOperands(llvm::ArrayRef<mlir::Value>, unsigned) {
   return {};
 }
 
-llvm::Optional<mlir::MutableOperandRange>
-fir::SelectOp::getMutableSuccessorOperands(unsigned oper) {
-  return ::getMutableSuccessorOperands(oper, getTargetArgsMutable(),
-                                       getTargetOffsetAttr());
+mlir::SuccessorOperands fir::SelectOp::getSuccessorOperands(unsigned oper) {
+  return mlir::SuccessorOperands(::getMutableSuccessorOperands(
+      oper, getTargetArgsMutable(), getTargetOffsetAttr()));
 }
 
 llvm::Optional<llvm::ArrayRef<mlir::Value>>
@@ -2462,10 +2461,9 @@ fir::SelectCaseOp::getCompareOperands(mlir::ValueRange operands,
   return {getSubOperands(cond, getSubOperands(1, operands, segments), a)};
 }
 
-llvm::Optional<mlir::MutableOperandRange>
-fir::SelectCaseOp::getMutableSuccessorOperands(unsigned oper) {
-  return ::getMutableSuccessorOperands(oper, getTargetArgsMutable(),
-                                       getTargetOffsetAttr());
+mlir::SuccessorOperands fir::SelectCaseOp::getSuccessorOperands(unsigned oper) {
+  return mlir::SuccessorOperands(::getMutableSuccessorOperands(
+      oper, getTargetArgsMutable(), getTargetOffsetAttr()));
 }
 
 llvm::Optional<llvm::ArrayRef<mlir::Value>>
@@ -2734,10 +2732,9 @@ fir::SelectRankOp::getCompareOperands(llvm::ArrayRef<mlir::Value>, unsigned) {
   return {};
 }
 
-llvm::Optional<mlir::MutableOperandRange>
-fir::SelectRankOp::getMutableSuccessorOperands(unsigned oper) {
-  return ::getMutableSuccessorOperands(oper, getTargetArgsMutable(),
-                                       getTargetOffsetAttr());
+mlir::SuccessorOperands fir::SelectRankOp::getSuccessorOperands(unsigned oper) {
+  return mlir::SuccessorOperands(::getMutableSuccessorOperands(
+      oper, getTargetArgsMutable(), getTargetOffsetAttr()));
 }
 
 llvm::Optional<llvm::ArrayRef<mlir::Value>>
@@ -2779,10 +2776,9 @@ fir::SelectTypeOp::getCompareOperands(llvm::ArrayRef<mlir::Value>, unsigned) {
   return {};
 }
 
-llvm::Optional<mlir::MutableOperandRange>
-fir::SelectTypeOp::getMutableSuccessorOperands(unsigned oper) {
-  return ::getMutableSuccessorOperands(oper, getTargetArgsMutable(),
-                                       getTargetOffsetAttr());
+mlir::SuccessorOperands fir::SelectTypeOp::getSuccessorOperands(unsigned oper) {
+  return mlir::SuccessorOperands(::getMutableSuccessorOperands(
+      oper, getTargetArgsMutable(), getTargetOffsetAttr()));
 }
 
 llvm::Optional<llvm::ArrayRef<mlir::Value>>

diff  --git a/mlir/include/mlir/IR/OperationSupport.h b/mlir/include/mlir/IR/OperationSupport.h
index 3707747b5bff0..22cf6fb2423a7 100644
--- a/mlir/include/mlir/IR/OperationSupport.h
+++ b/mlir/include/mlir/IR/OperationSupport.h
@@ -907,6 +907,11 @@ class MutableOperandRange {
   /// elements attribute, which contains the sizes of the sub ranges.
   MutableOperandRangeRange split(NamedAttribute segmentSizes) const;
 
+  /// Returns the value at the given index.
+  Value operator[](unsigned index) const {
+    return static_cast<OperandRange>(*this)[index];
+  }
+
 private:
   /// Update the length of this range to the one provided.
   void updateLength(unsigned newLength);

diff  --git a/mlir/include/mlir/Interfaces/ControlFlowInterfaces.h b/mlir/include/mlir/Interfaces/ControlFlowInterfaces.h
index 1e8f7b54c474a..3fc73de2c0cd6 100644
--- a/mlir/include/mlir/Interfaces/ControlFlowInterfaces.h
+++ b/mlir/include/mlir/Interfaces/ControlFlowInterfaces.h
@@ -20,6 +20,106 @@ namespace mlir {
 class BranchOpInterface;
 class RegionBranchOpInterface;
 
+/// This class models how operands are forwarded to block arguments in control
+/// flow. It consists of a number, denoting how many of the successor's block
+/// arguments are produced by the operation, followed by a range of operands
+/// that are forwarded. The produced operands are passed to the first few
+/// block arguments of the successor, followed by the forwarded operands.
+/// It is unsupported to pass them in a different order.
+///
+/// An example operation with both of these concepts would be a branch-on-error
+/// operation, that internally produces an error object on the error path:
+///
+///   invoke %function(%0)
+///     label ^success ^error(%1 : i32)
+///
+///     ^error(%e: !error, %arg0 : i32):
+///       ...
+///
+/// This operation would return an instance of SuccessorOperands with a produced
+/// operand count of 1 (mapped to %e in the successor) and a forwarded
+/// operands range consisting of %1 in the example above (mapped to %arg0 in the
+/// successor).
+class SuccessorOperands {
+public:
+  /// Constructs a SuccessorOperands with no produced operands that simply
+  /// forwards operands to the successor.
+  explicit SuccessorOperands(MutableOperandRange forwardedOperands);
+
+  /// Constructs a SuccessorOperands with the given amount of produced operands
+  /// and forwarded operands.
+  SuccessorOperands(unsigned producedOperandCount,
+                    MutableOperandRange forwardedOperands);
+
+  /// Returns the amount of operands passed to the successor. This consists both
+  /// of produced operands by the operation as well as forwarded ones.
+  unsigned size() const {
+    return producedOperandCount + forwardedOperands.size();
+  }
+
+  /// Returns true if there are no successor operands.
+  bool empty() const { return size() == 0; }
+
+  /// Returns the amount of operands that are produced internally by the
+  /// operation. These are passed to the first few block arguments.
+  unsigned getProducedOperandCount() const { return producedOperandCount; }
+
+  /// Returns true if the successor operand denoted by `index` is produced by
+  /// the operation.
+  bool isOperandProduced(unsigned index) const {
+    return index < producedOperandCount;
+  }
+
+  /// Returns the Value that is passed to the successor's block argument denoted
+  /// by `index`. If it is produced by the operation, no such value exists and
+  /// a null Value is returned.
+  Value operator[](unsigned index) const {
+    if (isOperandProduced(index))
+      return Value();
+    return forwardedOperands[index - producedOperandCount];
+  }
+
+  /// Get the range of operands that are simply forwarded to the successor.
+  OperandRange getForwardedOperands() const { return forwardedOperands; }
+
+  /// Get a slice of the operands forwarded to the successor. The given range
+  /// must not contain any operands produced by the operation.
+  MutableOperandRange slice(unsigned subStart, unsigned subLen) const {
+    assert(!isOperandProduced(subStart) &&
+           "can't slice operands produced by the operation");
+    return forwardedOperands.slice(subStart - producedOperandCount, subLen);
+  }
+
+  /// Erase operands forwarded to the successor. The given range must
+  /// not contain any operands produced by the operation.
+  void erase(unsigned subStart, unsigned subLen = 1) {
+    assert(!isOperandProduced(subStart) &&
+           "can't erase operands produced by the operation");
+    forwardedOperands.erase(subStart - producedOperandCount, subLen);
+  }
+
+  /// Add new operands that are forwarded to the successor.
+  void append(ValueRange valueRange) { forwardedOperands.append(valueRange); }
+
+  /// Gets the index of the forwarded operand within the operation which maps
+  /// to the block argument denoted by `blockArgumentIndex`. The block argument
+  /// must be mapped to a forwarded operand.
+  unsigned getOperandIndex(unsigned blockArgumentIndex) const {
+    assert(!isOperandProduced(blockArgumentIndex) &&
+           "can't map operand produced by the operation");
+    return static_cast<mlir::OperandRange>(forwardedOperands)
+               .getBeginOperandIndex() +
+           (blockArgumentIndex - producedOperandCount);
+  }
+
+private:
+  /// Amount of operands that are produced internally within the operation and
+  /// passed to the first few block arguments.
+  unsigned producedOperandCount;
+  /// Range of operands that are forwarded to the remaining block arguments.
+  MutableOperandRange forwardedOperands;
+};
+
 //===----------------------------------------------------------------------===//
 // BranchOpInterface
 //===----------------------------------------------------------------------===//
@@ -29,12 +129,12 @@ namespace detail {
 /// successor if `operandIndex` is within the range of `operands`, or None if
 /// `operandIndex` isn't a successor operand index.
 Optional<BlockArgument>
-getBranchSuccessorArgument(Optional<OperandRange> operands,
+getBranchSuccessorArgument(const SuccessorOperands &operands,
                            unsigned operandIndex, Block *successor);
 
 /// Verify that the given operands match those of the given successor block.
 LogicalResult verifyBranchSuccessorOperands(Operation *op, unsigned succNo,
-                                            Optional<OperandRange> operands);
+                                            const SuccessorOperands &operands);
 } // namespace detail
 
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td b/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td
index 9d7b43b5e4a47..ac805ea8f218a 100644
--- a/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td
+++ b/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td
@@ -36,26 +36,35 @@ def BranchOpInterface : OpInterface<"BranchOpInterface"> {
 
   let methods = [
     InterfaceMethod<[{
-        Returns a mutable range of operands that correspond to the arguments of
-        successor at the given index. Returns None if the operands to the
-        successor are non-materialized values, i.e. they are internal to the
-        operation.
+        Returns the operands that correspond to the arguments of the successor
+        at the given index. It consists of a number of operands that are
+        internally produced by the operation, followed by a range of operands
+        that are forwarded. An example operation making use of produced
+        operands would be:
+
+        ```mlir
+        invoke %function(%0)
+            label ^success ^error(%1 : i32)
+
+        ^error(%e: !error, %arg0: i32):
+            ...
+        ```
+
+        The operand that maps to the `%e` argument of `^error` is produced
+        by the `invoke` operation, while `%1` is a forwarded operand that maps
+        to `%arg0` in the successor.
+
+        Produced operands always map to the first few block arguments of the
+        successor, followed by the forwarded operands. Mapping them in any
+        other order is not supported by the interface.
+
+        Having the forwarded operands last allows users of the interface
+        to append more forwarded operands to the branch operation without
+        interfering with other successor operands.
       }],
-      "::mlir::Optional<::mlir::MutableOperandRange>", "getMutableSuccessorOperands",
+      "::mlir::SuccessorOperands", "getSuccessorOperands",
       (ins "unsigned":$index)
     >,
-    InterfaceMethod<[{
-        Returns a range of operands that correspond to the arguments of
-        successor at the given index. Returns None if the operands to the
-        successor are non-materialized values, i.e. they are internal to the
-        operation.
-      }],
-      "::mlir::Optional<::mlir::OperandRange>", "getSuccessorOperands",
-      (ins "unsigned":$index), [{}], [{
-        auto operands = $_op.getMutableSuccessorOperands(index);
-        return operands ? ::mlir::Optional<::mlir::OperandRange>(*operands) : ::llvm::None;
-      }]
-    >,
     InterfaceMethod<[{
         Returns the `BlockArgument` corresponding to operand `operandIndex` in
         some successor, or None if `operandIndex` isn't a successor operand
@@ -94,7 +103,7 @@ def BranchOpInterface : OpInterface<"BranchOpInterface"> {
   let verify = [{
     auto concreteOp = ::mlir::cast<ConcreteOp>($_op);
     for (unsigned i = 0, e = $_op->getNumSuccessors(); i != e; ++i) {
-      ::mlir::Optional<OperandRange> operands = concreteOp.getSuccessorOperands(i);
+      ::mlir::SuccessorOperands operands = concreteOp.getSuccessorOperands(i);
       if (::mlir::failed(::mlir::detail::verifyBranchSuccessorOperands($_op, i, operands)))
         return ::mlir::failure();
     }

diff  --git a/mlir/lib/Analysis/AliasAnalysis/LocalAliasAnalysis.cpp b/mlir/lib/Analysis/AliasAnalysis/LocalAliasAnalysis.cpp
index e3b09bcf5888c..78eb0e414bdfa 100644
--- a/mlir/lib/Analysis/AliasAnalysis/LocalAliasAnalysis.cpp
+++ b/mlir/lib/Analysis/AliasAnalysis/LocalAliasAnalysis.cpp
@@ -149,14 +149,13 @@ static void collectUnderlyingAddressValues(BlockArgument arg, unsigned maxDepth,
 
       // Try to get the operand passed for this argument.
       unsigned index = it.getSuccessorIndex();
-      Optional<OperandRange> operands = branch.getSuccessorOperands(index);
-      if (!operands) {
+      Value operand = branch.getSuccessorOperands(index)[argNumber];
+      if (!operand) {
         // We can't analyze the control flow, so bail out early.
         output.push_back(arg);
         return;
       }
-      collectUnderlyingAddressValues((*operands)[argNumber], maxDepth, visited,
-                                     output);
+      collectUnderlyingAddressValues(operand, maxDepth, visited, output);
     }
     return;
   }

diff  --git a/mlir/lib/Analysis/BufferViewFlowAnalysis.cpp b/mlir/lib/Analysis/BufferViewFlowAnalysis.cpp
index 45766a25f791b..5b2b31db29498 100644
--- a/mlir/lib/Analysis/BufferViewFlowAnalysis.cpp
+++ b/mlir/lib/Analysis/BufferViewFlowAnalysis.cpp
@@ -70,10 +70,10 @@ void BufferViewFlowAnalysis::build(Operation *op) {
       // Query the branch op interface to get the successor operands.
       auto successorOperands =
           branchInterface.getSuccessorOperands(it.getIndex());
-      if (!successorOperands.hasValue())
-        continue;
       // Build the actual mapping of values to their immediate dependencies.
-      registerDependencies(successorOperands.getValue(), (*it)->getArguments());
+      registerDependencies(successorOperands.getForwardedOperands(),
+                           (*it)->getArguments().drop_front(
+                               successorOperands.getProducedOperandCount()));
     }
   });
 

diff  --git a/mlir/lib/Analysis/DataFlowAnalysis.cpp b/mlir/lib/Analysis/DataFlowAnalysis.cpp
index b8e801fea6db8..6718dee107fe5 100644
--- a/mlir/lib/Analysis/DataFlowAnalysis.cpp
+++ b/mlir/lib/Analysis/DataFlowAnalysis.cpp
@@ -681,10 +681,13 @@ void ForwardDataFlowSolver::visitBlockArgument(Block *block, int i) {
     // Try to get the operand forwarded by the predecessor. If we can't reason
     // about the terminator of the predecessor, mark as having reached a
     // fixpoint.
-    Optional<OperandRange> branchOperands;
-    if (auto branch = dyn_cast<BranchOpInterface>(pred->getTerminator()))
-      branchOperands = branch.getSuccessorOperands(it.getSuccessorIndex());
-    if (!branchOperands) {
+    auto branch = dyn_cast<BranchOpInterface>(pred->getTerminator());
+    if (!branch) {
+      updatedLattice |= argLattice.markPessimisticFixpoint();
+      break;
+    }
+    Value operand = branch.getSuccessorOperands(it.getSuccessorIndex())[i];
+    if (!operand) {
       updatedLattice |= argLattice.markPessimisticFixpoint();
       break;
     }
@@ -692,7 +695,7 @@ void ForwardDataFlowSolver::visitBlockArgument(Block *block, int i) {
     // If the operand hasn't been resolved, it is uninitialized and can merge
     // with anything.
     AbstractLatticeElement *operandLattice =
-        analysis.lookupLatticeElement((*branchOperands)[i]);
+        analysis.lookupLatticeElement(operand);
     if (!operandLattice)
       continue;
 

diff  --git a/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocation.cpp b/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocation.cpp
index 99ce070e94000..5d9dd6d1b7b61 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocation.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocation.cpp
@@ -325,25 +325,23 @@ class BufferDeallocation : public BufferPlacementTransformationBase {
       // argument.
       Operation *terminator = (*it)->getTerminator();
       auto branchInterface = cast<BranchOpInterface>(terminator);
+      SuccessorOperands operands =
+          branchInterface.getSuccessorOperands(it.getSuccessorIndex());
+
       // Query the associated source value.
-      Value sourceValue =
-          branchInterface.getSuccessorOperands(it.getSuccessorIndex())
-              .getValue()[blockArg.getArgNumber()];
-      // Wire new clone and successor operand.
-      auto mutableOperands =
-          branchInterface.getMutableSuccessorOperands(it.getSuccessorIndex());
-      if (!mutableOperands) {
-        terminator->emitError() << "terminators with immutable successor "
-                                   "operands are not supported";
-        continue;
+      Value sourceValue = operands[blockArg.getArgNumber()];
+      // Operands produced by the terminator itself have no SSA value to clone.
+      if (!sourceValue) {
+        terminator->emitError() << "terminators with produced successor "
+                                   "operands are not supported";
+        return failure();
       }
+      // Wire new clone and successor operand.
       // Create a new clone at the current location of the terminator.
       auto clone = introduceCloneBuffers(sourceValue, terminator);
       if (failed(clone))
         return failure();
-      mutableOperands.getValue()
-          .slice(blockArg.getArgNumber(), 1)
-          .assign(*clone);
+      operands.slice(blockArg.getArgNumber(), 1).assign(*clone);
     }

     // Check whether the block argument has implicitly defined predecessors via

diff  --git a/mlir/lib/Dialect/ControlFlow/IR/ControlFlowOps.cpp b/mlir/lib/Dialect/ControlFlow/IR/ControlFlowOps.cpp
index 03f0998ac85fc..9085ce7e86e89 100644
--- a/mlir/lib/Dialect/ControlFlow/IR/ControlFlowOps.cpp
+++ b/mlir/lib/Dialect/ControlFlow/IR/ControlFlowOps.cpp
@@ -186,10 +186,9 @@ void BranchOp::setDest(Block *block) { return setSuccessor(block); }
 
 void BranchOp::eraseOperand(unsigned index) { (*this)->eraseOperand(index); }
 
-Optional<MutableOperandRange>
-BranchOp::getMutableSuccessorOperands(unsigned index) {
+SuccessorOperands BranchOp::getSuccessorOperands(unsigned index) {
   assert(index == 0 && "invalid successor index");
-  return getDestOperandsMutable();
+  return SuccessorOperands(getDestOperandsMutable());
 }
 
 Block *BranchOp::getSuccessorForOperands(ArrayRef<Attribute>) {
@@ -437,11 +436,10 @@ void CondBranchOp::getCanonicalizationPatterns(RewritePatternSet &results,
               CondBranchTruthPropagation>(context);
 }
 
-Optional<MutableOperandRange>
-CondBranchOp::getMutableSuccessorOperands(unsigned index) {
+SuccessorOperands CondBranchOp::getSuccessorOperands(unsigned index) {
   assert(index < getNumSuccessors() && "invalid successor index");
-  return index == trueIndex ? getTrueDestOperandsMutable()
-                            : getFalseDestOperandsMutable();
+  return SuccessorOperands(index == trueIndex ? getTrueDestOperandsMutable()
+                                              : getFalseDestOperandsMutable());
 }
 
 Block *CondBranchOp::getSuccessorForOperands(ArrayRef<Attribute> operands) {
@@ -575,11 +573,10 @@ LogicalResult SwitchOp::verify() {
   return success();
 }
 
-Optional<MutableOperandRange>
-SwitchOp::getMutableSuccessorOperands(unsigned index) {
+SuccessorOperands SwitchOp::getSuccessorOperands(unsigned index) {
   assert(index < getNumSuccessors() && "invalid successor index");
-  return index == 0 ? getDefaultOperandsMutable()
-                    : getCaseOperandsMutable(index - 1);
+  return SuccessorOperands(index == 0 ? getDefaultOperandsMutable()
+                                      : getCaseOperandsMutable(index - 1));
 }
 
 Block *SwitchOp::getSuccessorForOperands(ArrayRef<Attribute> operands) {

diff  --git a/mlir/lib/Dialect/Func/Transforms/FuncConversions.cpp b/mlir/lib/Dialect/Func/Transforms/FuncConversions.cpp
index c1e69d0ed0ba8..7058b72b740d3 100644
--- a/mlir/lib/Dialect/Func/Transforms/FuncConversions.cpp
+++ b/mlir/lib/Dialect/Func/Transforms/FuncConversions.cpp
@@ -67,12 +67,13 @@ class BranchOpInterfaceTypeConversion
     SmallVector<Value, 4> newOperands(op->operand_begin(), op->operand_end());
     for (int succIdx = 0, succEnd = op->getBlock()->getNumSuccessors();
          succIdx < succEnd; ++succIdx) {
-      auto successorOperands = op.getSuccessorOperands(succIdx);
-      if (!successorOperands || successorOperands->empty())
+      OperandRange forwardedOperands =
+          op.getSuccessorOperands(succIdx).getForwardedOperands();
+      if (forwardedOperands.empty())
         continue;
 
-      for (int idx = successorOperands->getBeginOperandIndex(),
-               eidx = idx + successorOperands->size();
+      for (int idx = forwardedOperands.getBeginOperandIndex(),
+               eidx = idx + forwardedOperands.size();
            idx < eidx; ++idx) {
         if (!shouldConvertBranchOperand || shouldConvertBranchOperand(op, idx))
           newOperands[idx] = operands[idx];
@@ -121,8 +122,8 @@ bool mlir::isLegalForBranchOpInterfaceTypeConversionPattern(
   if (auto branchOp = dyn_cast<BranchOpInterface>(op)) {
     for (int p = 0, e = op->getBlock()->getNumSuccessors(); p < e; ++p) {
       auto successorOperands = branchOp.getSuccessorOperands(p);
-      if (successorOperands.hasValue() &&
-          !converter.isLegal(successorOperands.getValue().getTypes()))
+      if (!converter.isLegal(
+              successorOperands.getForwardedOperands().getTypes()))
         return false;
     }
     return true;

diff  --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
index ff93a506b4b53..e149667659a97 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
@@ -240,21 +240,19 @@ ParseResult AllocaOp::parse(OpAsmParser &parser, OperationState &result) {
 // LLVM::BrOp
 //===----------------------------------------------------------------------===//
 
-Optional<MutableOperandRange>
-BrOp::getMutableSuccessorOperands(unsigned index) {
+SuccessorOperands BrOp::getSuccessorOperands(unsigned index) {
   assert(index == 0 && "invalid successor index");
-  return getDestOperandsMutable();
+  return SuccessorOperands(getDestOperandsMutable());
 }
 
 //===----------------------------------------------------------------------===//
 // LLVM::CondBrOp
 //===----------------------------------------------------------------------===//
 
-Optional<MutableOperandRange>
-CondBrOp::getMutableSuccessorOperands(unsigned index) {
+SuccessorOperands CondBrOp::getSuccessorOperands(unsigned index) {
   assert(index < getNumSuccessors() && "invalid successor index");
-  return index == 0 ? getTrueDestOperandsMutable()
-                    : getFalseDestOperandsMutable();
+  return SuccessorOperands(index == 0 ? getTrueDestOperandsMutable()
+                                      : getFalseDestOperandsMutable());
 }
 
 //===----------------------------------------------------------------------===//
@@ -356,11 +354,10 @@ LogicalResult SwitchOp::verify() {
   return success();
 }
 
-Optional<MutableOperandRange>
-SwitchOp::getMutableSuccessorOperands(unsigned index) {
+SuccessorOperands SwitchOp::getSuccessorOperands(unsigned index) {
   assert(index < getNumSuccessors() && "invalid successor index");
-  return index == 0 ? getDefaultOperandsMutable()
-                    : getCaseOperandsMutable(index - 1);
+  return SuccessorOperands(index == 0 ? getDefaultOperandsMutable()
+                                      : getCaseOperandsMutable(index - 1));
 }
 
 //===----------------------------------------------------------------------===//
@@ -735,11 +732,10 @@ ParseResult StoreOp::parse(OpAsmParser &parser, OperationState &result) {
 /// LLVM::InvokeOp
 ///===---------------------------------------------------------------------===//
 
-Optional<MutableOperandRange>
-InvokeOp::getMutableSuccessorOperands(unsigned index) {
+SuccessorOperands InvokeOp::getSuccessorOperands(unsigned index) {
   assert(index < getNumSuccessors() && "invalid successor index");
-  return index == 0 ? getNormalDestOperandsMutable()
-                    : getUnwindDestOperandsMutable();
+  return SuccessorOperands(index == 0 ? getNormalDestOperandsMutable()
+                                      : getUnwindDestOperandsMutable());
 }
 
 LogicalResult InvokeOp::verify() {

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp b/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp
index 25aa5396686e1..9a31622b52f0b 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp
@@ -223,12 +223,12 @@ struct LinalgDetensorize : public LinalgDetensorizeBase<LinalgDetensorize> {
           auto blockOperands =
               terminator.getSuccessorOperands(pred.getSuccessorIndex());
 
-          if (!blockOperands || blockOperands->empty())
+          if (blockOperands.empty() ||
+              blockOperands.isOperandProduced(blockArgumentElem.getArgNumber()))
             continue;
 
           detensorableBranchOps[terminator].insert(
-              blockOperands->getBeginOperandIndex() +
-              blockArgumentElem.getArgNumber());
+              blockOperands.getOperandIndex(blockArgumentElem.getArgNumber()));
         }
       }
 
@@ -343,14 +343,15 @@ struct LinalgDetensorize : public LinalgDetensorizeBase<LinalgDetensorize> {
             auto ownerBlockOperands =
                 predTerminator.getSuccessorOperands(pred.getSuccessorIndex());
 
-            if (!ownerBlockOperands || ownerBlockOperands->empty())
+            if (ownerBlockOperands.empty() ||
+                ownerBlockOperands.isOperandProduced(
+                    currentItemBlockArgument.getArgNumber()))
               continue;
 
             // For each predecessor, add the value it passes to that argument to
             // workList to find out how it's computed.
             workList.push_back(
-                ownerBlockOperands
-                    .getValue()[currentItemBlockArgument.getArgNumber()]);
+                ownerBlockOperands[currentItemBlockArgument.getArgNumber()]);
           }
 
           continue;
@@ -418,18 +419,16 @@ struct LinalgDetensorize : public LinalgDetensorizeBase<LinalgDetensorize> {
           auto blockOperands =
               terminator.getSuccessorOperands(pred.getSuccessorIndex());
 
-          if (!blockOperands || blockOperands->empty())
+          if (blockOperands.empty() ||
+              blockOperands.isOperandProduced(blockArg.getArgNumber()))
             continue;
 
           Operation *definingOp =
-              terminator
-                  ->getOperand(blockOperands->getBeginOperandIndex() +
-                               blockArg.getArgNumber())
-                  .getDefiningOp();
+              blockOperands[blockArg.getArgNumber()].getDefiningOp();
 
           // If the operand is defined by a GenericOp that will not be
           // detensored, then do not detensor the corresponding block argument.
-          if (dyn_cast_or_null<GenericOp>(definingOp) &&
+          if (isa_and_nonnull<GenericOp>(definingOp) &&
               opsToDetensor.count(definingOp) == 0) {
             blockArgsToRemove.insert(blockArg);
             break;

diff  --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
index 4ffde49807e83..14d7f78f243e2 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
@@ -1515,21 +1515,20 @@ LogicalResult spirv::BitcastOp::verify() {
 // spv.BranchOp
 //===----------------------------------------------------------------------===//
 
-Optional<MutableOperandRange>
-spirv::BranchOp::getMutableSuccessorOperands(unsigned index) {
+SuccessorOperands spirv::BranchOp::getSuccessorOperands(unsigned index) {
   assert(index == 0 && "invalid successor index");
-  return targetOperandsMutable();
+  return SuccessorOperands(0, targetOperandsMutable());
 }
 
 //===----------------------------------------------------------------------===//
 // spv.BranchConditionalOp
 //===----------------------------------------------------------------------===//
 
-Optional<MutableOperandRange>
-spirv::BranchConditionalOp::getMutableSuccessorOperands(unsigned index) {
+SuccessorOperands
+spirv::BranchConditionalOp::getSuccessorOperands(unsigned index) {
   assert(index < 2 && "invalid successor index");
-  return index == kTrueIndex ? trueTargetOperandsMutable()
-                             : falseTargetOperandsMutable();
+  return SuccessorOperands(index == kTrueIndex ? trueTargetOperandsMutable()
+                                               : falseTargetOperandsMutable());
 }
 
 ParseResult spirv::BranchConditionalOp::parse(OpAsmParser &parser,

diff  --git a/mlir/lib/Interfaces/ControlFlowInterfaces.cpp b/mlir/lib/Interfaces/ControlFlowInterfaces.cpp
index 02845c011472a..69ed30ae7bdd5 100644
--- a/mlir/lib/Interfaces/ControlFlowInterfaces.cpp
+++ b/mlir/lib/Interfaces/ControlFlowInterfaces.cpp
@@ -18,6 +18,14 @@ using namespace mlir;
 
 #include "mlir/Interfaces/ControlFlowInterfaces.cpp.inc"
 
+SuccessorOperands::SuccessorOperands(MutableOperandRange forwardedOperands)
+    : producedOperandCount(0), forwardedOperands(forwardedOperands) {}
+
+SuccessorOperands::SuccessorOperands(unsigned int producedOperandCount,
+                                     MutableOperandRange forwardedOperands)
+    : producedOperandCount(producedOperandCount),
+      forwardedOperands(std::move(forwardedOperands)) {}
+
 //===----------------------------------------------------------------------===//
 // BranchOpInterface
 //===----------------------------------------------------------------------===//
@@ -26,32 +34,31 @@ using namespace mlir;
 /// successor if 'operandIndex' is within the range of 'operands', or None if
 /// `operandIndex` isn't a successor operand index.
 Optional<BlockArgument>
-detail::getBranchSuccessorArgument(Optional<OperandRange> operands,
+detail::getBranchSuccessorArgument(const SuccessorOperands &operands,
                                    unsigned operandIndex, Block *successor) {
+  OperandRange forwardedOperands = operands.getForwardedOperands();
   // Check that the operands are valid.
-  if (!operands || operands->empty())
+  if (forwardedOperands.empty())
     return llvm::None;
 
   // Check to ensure that this operand is within the range.
-  unsigned operandsStart = operands->getBeginOperandIndex();
+  unsigned operandsStart = forwardedOperands.getBeginOperandIndex();
   if (operandIndex < operandsStart ||
-      operandIndex >= (operandsStart + operands->size()))
+      operandIndex >= (operandsStart + forwardedOperands.size()))
     return llvm::None;
 
   // Index the successor.
-  unsigned argIndex = operandIndex - operandsStart;
+  unsigned argIndex =
+      operands.getProducedOperandCount() + operandIndex - operandsStart;
   return successor->getArgument(argIndex);
 }
 
 /// Verify that the given operands match those of the given successor block.
 LogicalResult
 detail::verifyBranchSuccessorOperands(Operation *op, unsigned succNo,
-                                      Optional<OperandRange> operands) {
-  if (!operands)
-    return success();
-
+                                      const SuccessorOperands &operands) {
   // Check the count.
-  unsigned operandCount = operands->size();
+  unsigned operandCount = operands.size();
   Block *destBB = op->getSuccessor(succNo);
   if (operandCount != destBB->getNumArguments())
     return op->emitError() << "branch has " << operandCount
@@ -60,10 +67,10 @@ detail::verifyBranchSuccessorOperands(Operation *op, unsigned succNo,
                            << destBB->getNumArguments();
 
   // Check the types.
-  auto operandIt = operands->begin();
-  for (unsigned i = 0; i != operandCount; ++i, ++operandIt) {
+  for (unsigned i = operands.getProducedOperandCount(); i != operandCount;
+       ++i) {
     if (!cast<BranchOpInterface>(op).areTypesCompatible(
-            (*operandIt).getType(), destBB->getArgument(i).getType()))
+            operands[i].getType(), destBB->getArgument(i).getType()))
       return op->emitError() << "type mismatch for bb argument #" << i
                              << " of successor #" << succNo;
   }

diff  --git a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
index b1cdb3554c3fc..953fb2461520f 100644
--- a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
@@ -441,10 +441,9 @@ static Value getPHISourceValue(Block *current, Block *pred,
   for (unsigned i = 0, e = terminator.getNumSuccessors(); i < e; ++i) {
     Block *successor = terminator.getSuccessor(i);
     auto branch = cast<BranchOpInterface>(terminator);
-    Optional<OperandRange> successorOperands = branch.getSuccessorOperands(i);
+    SuccessorOperands successorOperands = branch.getSuccessorOperands(i);
     assert(
-        (!seenSuccessors.contains(successor) ||
-         (successorOperands && successorOperands->empty())) &&
+        (!seenSuccessors.contains(successor) || successorOperands.empty()) &&
         "successors with arguments in LLVM branches must be different blocks");
     seenSuccessors.insert(successor);
   }

diff  --git a/mlir/lib/Transforms/Utils/RegionUtils.cpp b/mlir/lib/Transforms/Utils/RegionUtils.cpp
index 6ee3266f6a739..996588243f565 100644
--- a/mlir/lib/Transforms/Utils/RegionUtils.cpp
+++ b/mlir/lib/Transforms/Utils/RegionUtils.cpp
@@ -223,12 +223,14 @@ static void propagateTerminatorLiveness(Operation *op, LiveMap &liveMap) {
     return;
   }
 
-  // If we can't reason about the operands to a successor, conservatively mark
-  // all arguments as live.
+  // Operands produced internally by the terminator cannot be reasoned about,
+  // so conservatively mark the corresponding successor arguments as live.
   for (unsigned i = 0, e = op->getNumSuccessors(); i != e; ++i) {
-    if (!branchInterface.getMutableSuccessorOperands(i))
-      for (BlockArgument arg : op->getSuccessor(i)->getArguments())
-        liveMap.setProvedLive(arg);
+    SuccessorOperands successorOperands =
+        branchInterface.getSuccessorOperands(i);
+    for (unsigned opI = 0, opE = successorOperands.getProducedOperandCount();
+         opI != opE; ++opI)
+      liveMap.setProvedLive(op->getSuccessor(i)->getArgument(opI));
   }
 }
 
@@ -291,18 +293,15 @@ static void eraseTerminatorSuccessorOperands(Operation *terminator,
     // since it will promote later operands of the terminator being erased
     // first, reducing the quadratic-ness.
     unsigned succ = succE - succI - 1;
-    Optional<MutableOperandRange> succOperands =
-        branchOp.getMutableSuccessorOperands(succ);
-    if (!succOperands)
-      continue;
+    SuccessorOperands succOperands = branchOp.getSuccessorOperands(succ);
     Block *successor = terminator->getSuccessor(succ);
 
-    for (unsigned argI = 0, argE = succOperands->size(); argI < argE; ++argI) {
+    for (unsigned argI = 0, argE = succOperands.size(); argI < argE; ++argI) {
       // Iterating args in reverse is needed for correctness, to avoid
       // shifting later args when earlier args are erased.
       unsigned arg = argE - argI - 1;
       if (!liveMap.wasProvenLive(successor->getArgument(arg)))
-        succOperands->erase(arg);
+        succOperands.erase(arg);
     }
   }
 }
@@ -570,8 +569,7 @@ LogicalResult BlockMergeCluster::addToCluster(BlockEquivalenceData &blockData) {
 /// their operands updated.
 static bool ableToUpdatePredOperands(Block *block) {
   for (auto it = block->pred_begin(), e = block->pred_end(); it != e; ++it) {
-    auto branch = dyn_cast<BranchOpInterface>((*it)->getTerminator());
-    if (!branch || !branch.getMutableSuccessorOperands(it.getSuccessorIndex()))
+    if (!isa<BranchOpInterface>((*it)->getTerminator()))
       return false;
   }
   return true;
@@ -631,7 +629,7 @@ LogicalResult BlockMergeCluster::merge(RewriterBase &rewriter) {
            predIt != predE; ++predIt) {
         auto branch = cast<BranchOpInterface>((*predIt)->getTerminator());
         unsigned succIndex = predIt.getSuccessorIndex();
-        branch.getMutableSuccessorOperands(succIndex)->append(
+        branch.getSuccessorOperands(succIndex).append(
             newArguments[clusterIndex]);
       }
     };

diff  --git a/mlir/test/Transforms/sccp.mlir b/mlir/test/Transforms/sccp.mlir
index 4879ee8c54c40..a77fbe7a61a8c 100644
--- a/mlir/test/Transforms/sccp.mlir
+++ b/mlir/test/Transforms/sccp.mlir
@@ -198,3 +198,21 @@ func @recheck_executable_edge(%cond0: i1) -> (i1, i1) {
   // CHECK: return %[[X]], %[[Y]]
   return %x, %y : i1, i1
 }
+
+// CHECK-LABEL: func @simple_produced_operand
+func @simple_produced_operand() -> (i32, i32) {
+  // CHECK: %[[ONE:.*]] = arith.constant 1
+  %1 = arith.constant 1 : i32
+  "test.internal_br"(%1) [^bb1, ^bb2] {
+    operand_segment_sizes = dense<[0, 1]> : vector<2 x i32>
+  } : (i32) -> ()
+
+^bb1:
+  cf.br ^bb2(%1, %1 : i32, i32)
+
+^bb2(%arg1 : i32, %arg2 : i32):
+  // CHECK: ^bb2(%[[ARG:.*]]: i32, %{{.*}}: i32):
+  // CHECK: return %[[ARG]], %[[ONE]] : i32, i32
+
+  return %arg1, %arg2 : i32, i32
+}

diff  --git a/mlir/test/lib/Dialect/Test/TestDialect.cpp b/mlir/test/lib/Dialect/Test/TestDialect.cpp
index 73119805fcdf0..1f496ee2b09e3 100644
--- a/mlir/test/lib/Dialect/Test/TestDialect.cpp
+++ b/mlir/test/lib/Dialect/Test/TestDialect.cpp
@@ -335,22 +335,31 @@ TestDialect::getOperationPrinter(Operation *op) const {
 // TestBranchOp
 //===----------------------------------------------------------------------===//
 
-Optional<MutableOperandRange>
-TestBranchOp::getMutableSuccessorOperands(unsigned index) {
+SuccessorOperands TestBranchOp::getSuccessorOperands(unsigned index) {
   assert(index == 0 && "invalid successor index");
-  return getTargetOperandsMutable();
+  return SuccessorOperands(getTargetOperandsMutable());
 }
 
 //===----------------------------------------------------------------------===//
 // TestProducingBranchOp
 //===----------------------------------------------------------------------===//
 
-Optional<MutableOperandRange>
-TestProducingBranchOp::getMutableSuccessorOperands(unsigned index) {
+SuccessorOperands TestProducingBranchOp::getSuccessorOperands(unsigned index) {
   assert(index <= 1 && "invalid successor index");
   if (index == 1)
-    return getFirstOperandsMutable();
-  return getSecondOperandsMutable();
+    return SuccessorOperands(getFirstOperandsMutable());
+  return SuccessorOperands(getSecondOperandsMutable());
+}
+
+//===----------------------------------------------------------------------===//
+// TestInternalBranchOp
+//===----------------------------------------------------------------------===//
+
+SuccessorOperands TestInternalBranchOp::getSuccessorOperands(unsigned index) {
+  assert(index <= 1 && "invalid successor index");
+  if (index == 0)
+    return SuccessorOperands(0, getSuccessOperandsMutable());
+  return SuccessorOperands(1, getErrorOperandsMutable());
 }
 
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
index 9902b57323ca3..bccca927725e0 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -642,6 +642,17 @@ def TestProducingBranchOp : TEST_Op<"producing_br",
   let successors = (successor AnySuccessor:$first,AnySuccessor:$second);
 }
 
+// Produces an error value on the error path
+def TestInternalBranchOp : TEST_Op<"internal_br",
+	[DeclareOpInterfaceMethods<BranchOpInterface>, Terminator,
+	 AttrSizedOperandSegments]> {
+
+  let arguments = (ins Variadic<AnyType>:$successOperands,
+                       Variadic<AnyType>:$errorOperands);
+
+  let successors = (successor AnySuccessor:$successPath, AnySuccessor:$errorPath);
+}
+
 def AttrSizedOperandOp : TEST_Op<"attr_sized_operands",
                                  [AttrSizedOperandSegments]> {
   let arguments = (ins


        


More information about the Mlir-commits mailing list