[Mlir-commits] [mlir] 0c789db - [mlir] Add support for operation-produced successor arguments in BranchOpInterface
Markus Böck
llvmlistbot at llvm.org
Thu Apr 7 23:28:36 PDT 2022
Author: Markus Böck
Date: 2022-04-08T08:28:16+02:00
New Revision: 0c789db541c236abf47265331a2f2b0945aa7b93
URL: https://github.com/llvm/llvm-project/commit/0c789db541c236abf47265331a2f2b0945aa7b93
DIFF: https://github.com/llvm/llvm-project/commit/0c789db541c236abf47265331a2f2b0945aa7b93.diff
LOG: [mlir] Add support for operation-produced successor arguments in BranchOpInterface
This patch revamps the BranchOpInterface a bit and allows a proper implementation of what was previously `getMutableSuccessorOperands` for operations that internally produce the values for some of a successor's block arguments. A motivating example for this would be an invoke op with an error-handling path:
```
invoke %function(%0)
label ^success ^error(%1 : i32)
^error(%e: !error, %arg0 : i32):
...
```
The advantages of this are that any users of `BranchOpInterface` can still reason about the remaining block argument operands (such as `%1` in the example above), and can make use of the modifying capabilities to add more operands, erase an operand, etc.
The way this patch implements that functionality is via a new class called `SuccessorOperands`, which is now returned by `getSuccessorOperands`. It basically contains an `unsigned` denoting how many operation-produced operands exist, as well as a `MutableOperandRange` holding the usual forwarded operands we are used to. The produced operands are assumed to map to the first few block arguments, with the forwarded operands following afterwards. The role of `SuccessorOperands` is to provide various utility functions to modify and query the successor arguments of a `BranchOpInterface`.
Differential Revision: https://reviews.llvm.org/D123062
Added:
Modified:
flang/include/flang/Optimizer/Dialect/FIROps.td
flang/lib/Optimizer/Dialect/FIROps.cpp
mlir/include/mlir/IR/OperationSupport.h
mlir/include/mlir/Interfaces/ControlFlowInterfaces.h
mlir/include/mlir/Interfaces/ControlFlowInterfaces.td
mlir/lib/Analysis/AliasAnalysis/LocalAliasAnalysis.cpp
mlir/lib/Analysis/BufferViewFlowAnalysis.cpp
mlir/lib/Analysis/DataFlowAnalysis.cpp
mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocation.cpp
mlir/lib/Dialect/ControlFlow/IR/ControlFlowOps.cpp
mlir/lib/Dialect/Func/Transforms/FuncConversions.cpp
mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp
mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
mlir/lib/Interfaces/ControlFlowInterfaces.cpp
mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
mlir/lib/Transforms/Utils/RegionUtils.cpp
mlir/test/Transforms/sccp.mlir
mlir/test/lib/Dialect/Test/TestDialect.cpp
mlir/test/lib/Dialect/Test/TestOps.td
Removed:
################################################################################
diff --git a/flang/include/flang/Optimizer/Dialect/FIROps.td b/flang/include/flang/Optimizer/Dialect/FIROps.td
index 6eb0fdfea669d..e0c09396dc7c9 100644
--- a/flang/include/flang/Optimizer/Dialect/FIROps.td
+++ b/flang/include/flang/Optimizer/Dialect/FIROps.td
@@ -489,16 +489,12 @@ class fir_SwitchTerminatorOp<string mnemonic, list<Trait> traits = []> :
llvm::ArrayRef<mlir::Value> operands, unsigned cond);
llvm::Optional<mlir::ValueRange> getSuccessorOperands(
mlir::ValueRange operands, unsigned cond);
- using BranchOpInterfaceTrait::getSuccessorOperands;
// Helper function to deal with Optional operand forms
void printSuccessorAtIndex(mlir::OpAsmPrinter &p, unsigned i) {
auto *succ = getSuccessor(i);
auto ops = getSuccessorOperands(i);
- if (ops.hasValue())
- p.printSuccessorAndUseList(succ, ops.getValue());
- else
- p.printSuccessor(succ);
+ p.printSuccessorAndUseList(succ, ops.getForwardedOperands());
}
mlir::ArrayAttr getCases() {
diff --git a/flang/lib/Optimizer/Dialect/FIROps.cpp b/flang/lib/Optimizer/Dialect/FIROps.cpp
index 6dfc5a90d8fcb..1b48a56adbcaf 100644
--- a/flang/lib/Optimizer/Dialect/FIROps.cpp
+++ b/flang/lib/Optimizer/Dialect/FIROps.cpp
@@ -2401,10 +2401,9 @@ fir::SelectOp::getCompareOperands(llvm::ArrayRef<mlir::Value>, unsigned) {
return {};
}
-llvm::Optional<mlir::MutableOperandRange>
-fir::SelectOp::getMutableSuccessorOperands(unsigned oper) {
- return ::getMutableSuccessorOperands(oper, getTargetArgsMutable(),
- getTargetOffsetAttr());
+mlir::SuccessorOperands fir::SelectOp::getSuccessorOperands(unsigned oper) {
+ return mlir::SuccessorOperands(::getMutableSuccessorOperands(
+ oper, getTargetArgsMutable(), getTargetOffsetAttr()));
}
llvm::Optional<llvm::ArrayRef<mlir::Value>>
@@ -2462,10 +2461,9 @@ fir::SelectCaseOp::getCompareOperands(mlir::ValueRange operands,
return {getSubOperands(cond, getSubOperands(1, operands, segments), a)};
}
-llvm::Optional<mlir::MutableOperandRange>
-fir::SelectCaseOp::getMutableSuccessorOperands(unsigned oper) {
- return ::getMutableSuccessorOperands(oper, getTargetArgsMutable(),
- getTargetOffsetAttr());
+mlir::SuccessorOperands fir::SelectCaseOp::getSuccessorOperands(unsigned oper) {
+ return mlir::SuccessorOperands(::getMutableSuccessorOperands(
+ oper, getTargetArgsMutable(), getTargetOffsetAttr()));
}
llvm::Optional<llvm::ArrayRef<mlir::Value>>
@@ -2734,10 +2732,9 @@ fir::SelectRankOp::getCompareOperands(llvm::ArrayRef<mlir::Value>, unsigned) {
return {};
}
-llvm::Optional<mlir::MutableOperandRange>
-fir::SelectRankOp::getMutableSuccessorOperands(unsigned oper) {
- return ::getMutableSuccessorOperands(oper, getTargetArgsMutable(),
- getTargetOffsetAttr());
+mlir::SuccessorOperands fir::SelectRankOp::getSuccessorOperands(unsigned oper) {
+ return mlir::SuccessorOperands(::getMutableSuccessorOperands(
+ oper, getTargetArgsMutable(), getTargetOffsetAttr()));
}
llvm::Optional<llvm::ArrayRef<mlir::Value>>
@@ -2779,10 +2776,9 @@ fir::SelectTypeOp::getCompareOperands(llvm::ArrayRef<mlir::Value>, unsigned) {
return {};
}
-llvm::Optional<mlir::MutableOperandRange>
-fir::SelectTypeOp::getMutableSuccessorOperands(unsigned oper) {
- return ::getMutableSuccessorOperands(oper, getTargetArgsMutable(),
- getTargetOffsetAttr());
+mlir::SuccessorOperands fir::SelectTypeOp::getSuccessorOperands(unsigned oper) {
+ return mlir::SuccessorOperands(::getMutableSuccessorOperands(
+ oper, getTargetArgsMutable(), getTargetOffsetAttr()));
}
llvm::Optional<llvm::ArrayRef<mlir::Value>>
diff --git a/mlir/include/mlir/IR/OperationSupport.h b/mlir/include/mlir/IR/OperationSupport.h
index 3707747b5bff0..22cf6fb2423a7 100644
--- a/mlir/include/mlir/IR/OperationSupport.h
+++ b/mlir/include/mlir/IR/OperationSupport.h
@@ -907,6 +907,11 @@ class MutableOperandRange {
/// elements attribute, which contains the sizes of the sub ranges.
MutableOperandRangeRange split(NamedAttribute segmentSizes) const;
+ /// Returns the value at the given index.
+ Value operator[](unsigned index) const {
+ return static_cast<OperandRange>(*this)[index];
+ }
+
private:
/// Update the length of this range to the one provided.
void updateLength(unsigned newLength);
diff --git a/mlir/include/mlir/Interfaces/ControlFlowInterfaces.h b/mlir/include/mlir/Interfaces/ControlFlowInterfaces.h
index 1e8f7b54c474a..3fc73de2c0cd6 100644
--- a/mlir/include/mlir/Interfaces/ControlFlowInterfaces.h
+++ b/mlir/include/mlir/Interfaces/ControlFlowInterfaces.h
@@ -20,6 +20,106 @@ namespace mlir {
class BranchOpInterface;
class RegionBranchOpInterface;
+/// This class models how operands are forwarded to block arguments in control
+/// flow. It consists of a number, denoting how many of the successors block
+/// arguments are produced by the operation, followed by a range of operands
+/// that are forwarded. The produced operands are passed to the first few
+/// block arguments of the successor, followed by the forwarded operands.
+/// It is unsupported to pass them in a different order.
+///
+/// An example operation with both of these concepts would be a branch-on-error
+/// operation, that internally produces an error object on the error path:
+///
+/// invoke %function(%0)
+/// label ^success ^error(%1 : i32)
+///
+/// ^error(%e: !error, %arg0 : i32):
+/// ...
+///
+/// This operation would return an instance of SuccessorOperands with a produced
+/// operand count of 1 (mapped to %e in the successor) and a forwarded
+/// operands range consisting of %1 in the example above (mapped to %arg0 in the
+/// successor).
+class SuccessorOperands {
+public:
+ /// Constructs a SuccessorOperands with no produced operands that simply
+ /// forwards operands to the successor.
+ explicit SuccessorOperands(MutableOperandRange forwardedOperands);
+
+ /// Constructs a SuccessorOperands with the given amount of produced operands
+ /// and forwarded operands.
+ SuccessorOperands(unsigned producedOperandCount,
+ MutableOperandRange forwardedOperands);
+
+ /// Returns the amount of operands passed to the successor. This consists both
+ /// of produced operands by the operation as well as forwarded ones.
+ unsigned size() const {
+ return producedOperandCount + forwardedOperands.size();
+ }
+
+ /// Returns true if there are no successor operands.
+ bool empty() const { return size() == 0; }
+
+ /// Returns the amount of operands that are produced internally by the
+ /// operation. These are passed to the first few block arguments.
+ unsigned getProducedOperandCount() const { return producedOperandCount; }
+
+ /// Returns true if the successor operand denoted by `index` is produced by
+ /// the operation.
+ bool isOperandProduced(unsigned index) const {
+ return index < producedOperandCount;
+ }
+
+ /// Returns the Value that is passed to the successors block argument denoted
+ /// by `index`. If it is produced by the operation, no such value exists and
+ /// a null Value is returned.
+ Value operator[](unsigned index) const {
+ if (isOperandProduced(index))
+ return Value();
+ return forwardedOperands[index - producedOperandCount];
+ }
+
+ /// Get the range of operands that are simply forwarded to the successor.
+ OperandRange getForwardedOperands() const { return forwardedOperands; }
+
+ /// Get a slice of the operands forwarded to the successor. The given range
+ /// must not contain any operands produced by the operation.
+ MutableOperandRange slice(unsigned subStart, unsigned subLen) const {
+ assert(!isOperandProduced(subStart) &&
+ "can't slice operands produced by the operation");
+ return forwardedOperands.slice(subStart - producedOperandCount, subLen);
+ }
+
+ /// Erase operands forwarded to the successor. The given range must
+ /// not contain any operands produced by the operation.
+ void erase(unsigned subStart, unsigned subLen = 1) {
+ assert(!isOperandProduced(subStart) &&
+ "can't erase operands produced by the operation");
+ forwardedOperands.erase(subStart - producedOperandCount, subLen);
+ }
+
+ /// Add new operands that are forwarded to the successor.
+ void append(ValueRange valueRange) { forwardedOperands.append(valueRange); }
+
+ /// Gets the index of the forwarded operand within the operation which maps
+ /// to the block argument denoted by `blockArgumentIndex`. The block argument
+ /// must be mapped to a forwarded operand.
+ unsigned getOperandIndex(unsigned blockArgumentIndex) const {
+ assert(!isOperandProduced(blockArgumentIndex) &&
+ "can't map operand produced by the operation");
+ return static_cast<mlir::OperandRange>(forwardedOperands)
+ .getBeginOperandIndex() +
+ (blockArgumentIndex - producedOperandCount);
+ }
+
+private:
+ /// Amount of operands that are produced internally within the operation and
+ /// passed to the first few block arguments.
+ unsigned producedOperandCount;
+ /// Range of operands that are forwarded to the remaining block arguments.
+ MutableOperandRange forwardedOperands;
+};
+
//===----------------------------------------------------------------------===//
// BranchOpInterface
//===----------------------------------------------------------------------===//
@@ -29,12 +129,12 @@ namespace detail {
/// successor if `operandIndex` is within the range of `operands`, or None if
/// `operandIndex` isn't a successor operand index.
Optional<BlockArgument>
-getBranchSuccessorArgument(Optional<OperandRange> operands,
+getBranchSuccessorArgument(const SuccessorOperands &operands,
unsigned operandIndex, Block *successor);
/// Verify that the given operands match those of the given successor block.
LogicalResult verifyBranchSuccessorOperands(Operation *op, unsigned succNo,
- Optional<OperandRange> operands);
+ const SuccessorOperands &operands);
} // namespace detail
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td b/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td
index 9d7b43b5e4a47..ac805ea8f218a 100644
--- a/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td
+++ b/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td
@@ -36,26 +36,35 @@ def BranchOpInterface : OpInterface<"BranchOpInterface"> {
let methods = [
InterfaceMethod<[{
- Returns a mutable range of operands that correspond to the arguments of
- successor at the given index. Returns None if the operands to the
- successor are non-materialized values, i.e. they are internal to the
- operation.
+ Returns the operands that correspond to the arguments of the successor
+ at the given index. It consists of a number of operands that are
+ internally produced by the operation, followed by a range of operands
+ that are forwarded. An example operation making use of produced
+ operands would be:
+
+ ```mlir
+ invoke %function(%0)
+ label ^success ^error(%1 : i32)
+
+ ^error(%e: !error, %arg0: i32):
+ ...
+ ```
+
+ The operand that would map to the `^error`s `%e` operand is produced
+ by the `invoke` operation, while `%1` is a forwarded operand that maps
+ to `%arg0` in the successor.
+
+ Produced operands always map to the first few block arguments of the
+ successor, followed by the forwarded operands. Mapping them in any
+ other order is not supported by the interface.
+
+ By having the forwarded operands last allows users of the interface
+ to append more forwarded operands to the branch operation without
+ interfering with other successor operands.
}],
- "::mlir::Optional<::mlir::MutableOperandRange>", "getMutableSuccessorOperands",
+ "::mlir::SuccessorOperands", "getSuccessorOperands",
(ins "unsigned":$index)
>,
- InterfaceMethod<[{
- Returns a range of operands that correspond to the arguments of
- successor at the given index. Returns None if the operands to the
- successor are non-materialized values, i.e. they are internal to the
- operation.
- }],
- "::mlir::Optional<::mlir::OperandRange>", "getSuccessorOperands",
- (ins "unsigned":$index), [{}], [{
- auto operands = $_op.getMutableSuccessorOperands(index);
- return operands ? ::mlir::Optional<::mlir::OperandRange>(*operands) : ::llvm::None;
- }]
- >,
InterfaceMethod<[{
Returns the `BlockArgument` corresponding to operand `operandIndex` in
some successor, or None if `operandIndex` isn't a successor operand
@@ -94,7 +103,7 @@ def BranchOpInterface : OpInterface<"BranchOpInterface"> {
let verify = [{
auto concreteOp = ::mlir::cast<ConcreteOp>($_op);
for (unsigned i = 0, e = $_op->getNumSuccessors(); i != e; ++i) {
- ::mlir::Optional<OperandRange> operands = concreteOp.getSuccessorOperands(i);
+ ::mlir::SuccessorOperands operands = concreteOp.getSuccessorOperands(i);
if (::mlir::failed(::mlir::detail::verifyBranchSuccessorOperands($_op, i, operands)))
return ::mlir::failure();
}
diff --git a/mlir/lib/Analysis/AliasAnalysis/LocalAliasAnalysis.cpp b/mlir/lib/Analysis/AliasAnalysis/LocalAliasAnalysis.cpp
index e3b09bcf5888c..78eb0e414bdfa 100644
--- a/mlir/lib/Analysis/AliasAnalysis/LocalAliasAnalysis.cpp
+++ b/mlir/lib/Analysis/AliasAnalysis/LocalAliasAnalysis.cpp
@@ -149,14 +149,13 @@ static void collectUnderlyingAddressValues(BlockArgument arg, unsigned maxDepth,
// Try to get the operand passed for this argument.
unsigned index = it.getSuccessorIndex();
- Optional<OperandRange> operands = branch.getSuccessorOperands(index);
- if (!operands) {
+ Value operand = branch.getSuccessorOperands(index)[argNumber];
+ if (!operand) {
// We can't analyze the control flow, so bail out early.
output.push_back(arg);
return;
}
- collectUnderlyingAddressValues((*operands)[argNumber], maxDepth, visited,
- output);
+ collectUnderlyingAddressValues(operand, maxDepth, visited, output);
}
return;
}
diff --git a/mlir/lib/Analysis/BufferViewFlowAnalysis.cpp b/mlir/lib/Analysis/BufferViewFlowAnalysis.cpp
index 45766a25f791b..5b2b31db29498 100644
--- a/mlir/lib/Analysis/BufferViewFlowAnalysis.cpp
+++ b/mlir/lib/Analysis/BufferViewFlowAnalysis.cpp
@@ -70,10 +70,10 @@ void BufferViewFlowAnalysis::build(Operation *op) {
// Query the branch op interface to get the successor operands.
auto successorOperands =
branchInterface.getSuccessorOperands(it.getIndex());
- if (!successorOperands.hasValue())
- continue;
// Build the actual mapping of values to their immediate dependencies.
- registerDependencies(successorOperands.getValue(), (*it)->getArguments());
+ registerDependencies(successorOperands.getForwardedOperands(),
+ (*it)->getArguments().drop_front(
+ successorOperands.getProducedOperandCount()));
}
});
diff --git a/mlir/lib/Analysis/DataFlowAnalysis.cpp b/mlir/lib/Analysis/DataFlowAnalysis.cpp
index b8e801fea6db8..6718dee107fe5 100644
--- a/mlir/lib/Analysis/DataFlowAnalysis.cpp
+++ b/mlir/lib/Analysis/DataFlowAnalysis.cpp
@@ -681,10 +681,13 @@ void ForwardDataFlowSolver::visitBlockArgument(Block *block, int i) {
// Try to get the operand forwarded by the predecessor. If we can't reason
// about the terminator of the predecessor, mark as having reached a
// fixpoint.
- Optional<OperandRange> branchOperands;
- if (auto branch = dyn_cast<BranchOpInterface>(pred->getTerminator()))
- branchOperands = branch.getSuccessorOperands(it.getSuccessorIndex());
- if (!branchOperands) {
+ auto branch = dyn_cast<BranchOpInterface>(pred->getTerminator());
+ if (!branch) {
+ updatedLattice |= argLattice.markPessimisticFixpoint();
+ break;
+ }
+ Value operand = branch.getSuccessorOperands(it.getSuccessorIndex())[i];
+ if (!operand) {
updatedLattice |= argLattice.markPessimisticFixpoint();
break;
}
@@ -692,7 +695,7 @@ void ForwardDataFlowSolver::visitBlockArgument(Block *block, int i) {
// If the operand hasn't been resolved, it is uninitialized and can merge
// with anything.
AbstractLatticeElement *operandLattice =
- analysis.lookupLatticeElement((*branchOperands)[i]);
+ analysis.lookupLatticeElement(operand);
if (!operandLattice)
continue;
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocation.cpp b/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocation.cpp
index 99ce070e94000..5d9dd6d1b7b61 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocation.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocation.cpp
@@ -325,25 +325,20 @@ class BufferDeallocation : public BufferPlacementTransformationBase {
// argument.
Operation *terminator = (*it)->getTerminator();
auto branchInterface = cast<BranchOpInterface>(terminator);
+ SuccessorOperands operands =
+ branchInterface.getSuccessorOperands(it.getSuccessorIndex());
+
// Query the associated source value.
- Value sourceValue =
- branchInterface.getSuccessorOperands(it.getSuccessorIndex())
- .getValue()[blockArg.getArgNumber()];
- // Wire new clone and successor operand.
- auto mutableOperands =
- branchInterface.getMutableSuccessorOperands(it.getSuccessorIndex());
- if (!mutableOperands) {
- terminator->emitError() << "terminators with immutable successor "
- "operands are not supported";
- continue;
+ Value sourceValue = operands[blockArg.getArgNumber()];
+ if (!sourceValue) {
+ return failure();
}
+ // Wire new clone and successor operand.
// Create a new clone at the current location of the terminator.
auto clone = introduceCloneBuffers(sourceValue, terminator);
if (failed(clone))
return failure();
- mutableOperands.getValue()
- .slice(blockArg.getArgNumber(), 1)
- .assign(*clone);
+ operands.slice(blockArg.getArgNumber(), 1).assign(*clone);
}
// Check whether the block argument has implicitly defined predecessors via
diff --git a/mlir/lib/Dialect/ControlFlow/IR/ControlFlowOps.cpp b/mlir/lib/Dialect/ControlFlow/IR/ControlFlowOps.cpp
index 03f0998ac85fc..9085ce7e86e89 100644
--- a/mlir/lib/Dialect/ControlFlow/IR/ControlFlowOps.cpp
+++ b/mlir/lib/Dialect/ControlFlow/IR/ControlFlowOps.cpp
@@ -186,10 +186,9 @@ void BranchOp::setDest(Block *block) { return setSuccessor(block); }
void BranchOp::eraseOperand(unsigned index) { (*this)->eraseOperand(index); }
-Optional<MutableOperandRange>
-BranchOp::getMutableSuccessorOperands(unsigned index) {
+SuccessorOperands BranchOp::getSuccessorOperands(unsigned index) {
assert(index == 0 && "invalid successor index");
- return getDestOperandsMutable();
+ return SuccessorOperands(getDestOperandsMutable());
}
Block *BranchOp::getSuccessorForOperands(ArrayRef<Attribute>) {
@@ -437,11 +436,10 @@ void CondBranchOp::getCanonicalizationPatterns(RewritePatternSet &results,
CondBranchTruthPropagation>(context);
}
-Optional<MutableOperandRange>
-CondBranchOp::getMutableSuccessorOperands(unsigned index) {
+SuccessorOperands CondBranchOp::getSuccessorOperands(unsigned index) {
assert(index < getNumSuccessors() && "invalid successor index");
- return index == trueIndex ? getTrueDestOperandsMutable()
- : getFalseDestOperandsMutable();
+ return SuccessorOperands(index == trueIndex ? getTrueDestOperandsMutable()
+ : getFalseDestOperandsMutable());
}
Block *CondBranchOp::getSuccessorForOperands(ArrayRef<Attribute> operands) {
@@ -575,11 +573,10 @@ LogicalResult SwitchOp::verify() {
return success();
}
-Optional<MutableOperandRange>
-SwitchOp::getMutableSuccessorOperands(unsigned index) {
+SuccessorOperands SwitchOp::getSuccessorOperands(unsigned index) {
assert(index < getNumSuccessors() && "invalid successor index");
- return index == 0 ? getDefaultOperandsMutable()
- : getCaseOperandsMutable(index - 1);
+ return SuccessorOperands(index == 0 ? getDefaultOperandsMutable()
+ : getCaseOperandsMutable(index - 1));
}
Block *SwitchOp::getSuccessorForOperands(ArrayRef<Attribute> operands) {
diff --git a/mlir/lib/Dialect/Func/Transforms/FuncConversions.cpp b/mlir/lib/Dialect/Func/Transforms/FuncConversions.cpp
index c1e69d0ed0ba8..7058b72b740d3 100644
--- a/mlir/lib/Dialect/Func/Transforms/FuncConversions.cpp
+++ b/mlir/lib/Dialect/Func/Transforms/FuncConversions.cpp
@@ -67,12 +67,13 @@ class BranchOpInterfaceTypeConversion
SmallVector<Value, 4> newOperands(op->operand_begin(), op->operand_end());
for (int succIdx = 0, succEnd = op->getBlock()->getNumSuccessors();
succIdx < succEnd; ++succIdx) {
- auto successorOperands = op.getSuccessorOperands(succIdx);
- if (!successorOperands || successorOperands->empty())
+ OperandRange forwardedOperands =
+ op.getSuccessorOperands(succIdx).getForwardedOperands();
+ if (forwardedOperands.empty())
continue;
- for (int idx = successorOperands->getBeginOperandIndex(),
- eidx = idx + successorOperands->size();
+ for (int idx = forwardedOperands.getBeginOperandIndex(),
+ eidx = idx + forwardedOperands.size();
idx < eidx; ++idx) {
if (!shouldConvertBranchOperand || shouldConvertBranchOperand(op, idx))
newOperands[idx] = operands[idx];
@@ -121,8 +122,8 @@ bool mlir::isLegalForBranchOpInterfaceTypeConversionPattern(
if (auto branchOp = dyn_cast<BranchOpInterface>(op)) {
for (int p = 0, e = op->getBlock()->getNumSuccessors(); p < e; ++p) {
auto successorOperands = branchOp.getSuccessorOperands(p);
- if (successorOperands.hasValue() &&
- !converter.isLegal(successorOperands.getValue().getTypes()))
+ if (!converter.isLegal(
+ successorOperands.getForwardedOperands().getTypes()))
return false;
}
return true;
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
index ff93a506b4b53..e149667659a97 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
@@ -240,21 +240,19 @@ ParseResult AllocaOp::parse(OpAsmParser &parser, OperationState &result) {
// LLVM::BrOp
//===----------------------------------------------------------------------===//
-Optional<MutableOperandRange>
-BrOp::getMutableSuccessorOperands(unsigned index) {
+SuccessorOperands BrOp::getSuccessorOperands(unsigned index) {
assert(index == 0 && "invalid successor index");
- return getDestOperandsMutable();
+ return SuccessorOperands(getDestOperandsMutable());
}
//===----------------------------------------------------------------------===//
// LLVM::CondBrOp
//===----------------------------------------------------------------------===//
-Optional<MutableOperandRange>
-CondBrOp::getMutableSuccessorOperands(unsigned index) {
+SuccessorOperands CondBrOp::getSuccessorOperands(unsigned index) {
assert(index < getNumSuccessors() && "invalid successor index");
- return index == 0 ? getTrueDestOperandsMutable()
- : getFalseDestOperandsMutable();
+ return SuccessorOperands(index == 0 ? getTrueDestOperandsMutable()
+ : getFalseDestOperandsMutable());
}
//===----------------------------------------------------------------------===//
@@ -356,11 +354,10 @@ LogicalResult SwitchOp::verify() {
return success();
}
-Optional<MutableOperandRange>
-SwitchOp::getMutableSuccessorOperands(unsigned index) {
+SuccessorOperands SwitchOp::getSuccessorOperands(unsigned index) {
assert(index < getNumSuccessors() && "invalid successor index");
- return index == 0 ? getDefaultOperandsMutable()
- : getCaseOperandsMutable(index - 1);
+ return SuccessorOperands(index == 0 ? getDefaultOperandsMutable()
+ : getCaseOperandsMutable(index - 1));
}
//===----------------------------------------------------------------------===//
@@ -735,11 +732,10 @@ ParseResult StoreOp::parse(OpAsmParser &parser, OperationState &result) {
/// LLVM::InvokeOp
///===---------------------------------------------------------------------===//
-Optional<MutableOperandRange>
-InvokeOp::getMutableSuccessorOperands(unsigned index) {
+SuccessorOperands InvokeOp::getSuccessorOperands(unsigned index) {
assert(index < getNumSuccessors() && "invalid successor index");
- return index == 0 ? getNormalDestOperandsMutable()
- : getUnwindDestOperandsMutable();
+ return SuccessorOperands(index == 0 ? getNormalDestOperandsMutable()
+ : getUnwindDestOperandsMutable());
}
LogicalResult InvokeOp::verify() {
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp b/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp
index 25aa5396686e1..9a31622b52f0b 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp
@@ -223,12 +223,12 @@ struct LinalgDetensorize : public LinalgDetensorizeBase<LinalgDetensorize> {
auto blockOperands =
terminator.getSuccessorOperands(pred.getSuccessorIndex());
- if (!blockOperands || blockOperands->empty())
+ if (blockOperands.empty() ||
+ blockOperands.isOperandProduced(blockArgumentElem.getArgNumber()))
continue;
detensorableBranchOps[terminator].insert(
- blockOperands->getBeginOperandIndex() +
- blockArgumentElem.getArgNumber());
+ blockOperands.getOperandIndex(blockArgumentElem.getArgNumber()));
}
}
@@ -343,14 +343,15 @@ struct LinalgDetensorize : public LinalgDetensorizeBase<LinalgDetensorize> {
auto ownerBlockOperands =
predTerminator.getSuccessorOperands(pred.getSuccessorIndex());
- if (!ownerBlockOperands || ownerBlockOperands->empty())
+ if (ownerBlockOperands.empty() ||
+ ownerBlockOperands.isOperandProduced(
+ currentItemBlockArgument.getArgNumber()))
continue;
// For each predecessor, add the value it passes to that argument to
// workList to find out how it's computed.
workList.push_back(
- ownerBlockOperands
- .getValue()[currentItemBlockArgument.getArgNumber()]);
+ ownerBlockOperands[currentItemBlockArgument.getArgNumber()]);
}
continue;
@@ -418,18 +419,16 @@ struct LinalgDetensorize : public LinalgDetensorizeBase<LinalgDetensorize> {
auto blockOperands =
terminator.getSuccessorOperands(pred.getSuccessorIndex());
- if (!blockOperands || blockOperands->empty())
+ if (blockOperands.empty() ||
+ blockOperands.isOperandProduced(blockArg.getArgNumber()))
continue;
Operation *definingOp =
- terminator
- ->getOperand(blockOperands->getBeginOperandIndex() +
- blockArg.getArgNumber())
- .getDefiningOp();
+ blockOperands[blockArg.getArgNumber()].getDefiningOp();
// If the operand is defined by a GenericOp that will not be
// detensored, then do not detensor the corresponding block argument.
- if (dyn_cast_or_null<GenericOp>(definingOp) &&
+ if (isa_and_nonnull<GenericOp>(definingOp) &&
opsToDetensor.count(definingOp) == 0) {
blockArgsToRemove.insert(blockArg);
break;
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
index 4ffde49807e83..14d7f78f243e2 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
@@ -1515,21 +1515,20 @@ LogicalResult spirv::BitcastOp::verify() {
// spv.BranchOp
//===----------------------------------------------------------------------===//
-Optional<MutableOperandRange>
-spirv::BranchOp::getMutableSuccessorOperands(unsigned index) {
+SuccessorOperands spirv::BranchOp::getSuccessorOperands(unsigned index) {
assert(index == 0 && "invalid successor index");
- return targetOperandsMutable();
+ return SuccessorOperands(0, targetOperandsMutable());
}
//===----------------------------------------------------------------------===//
// spv.BranchConditionalOp
//===----------------------------------------------------------------------===//
-Optional<MutableOperandRange>
-spirv::BranchConditionalOp::getMutableSuccessorOperands(unsigned index) {
+SuccessorOperands
+spirv::BranchConditionalOp::getSuccessorOperands(unsigned index) {
assert(index < 2 && "invalid successor index");
- return index == kTrueIndex ? trueTargetOperandsMutable()
- : falseTargetOperandsMutable();
+ return SuccessorOperands(index == kTrueIndex ? trueTargetOperandsMutable()
+ : falseTargetOperandsMutable());
}
ParseResult spirv::BranchConditionalOp::parse(OpAsmParser &parser,
diff --git a/mlir/lib/Interfaces/ControlFlowInterfaces.cpp b/mlir/lib/Interfaces/ControlFlowInterfaces.cpp
index 02845c011472a..69ed30ae7bdd5 100644
--- a/mlir/lib/Interfaces/ControlFlowInterfaces.cpp
+++ b/mlir/lib/Interfaces/ControlFlowInterfaces.cpp
@@ -18,6 +18,14 @@ using namespace mlir;
#include "mlir/Interfaces/ControlFlowInterfaces.cpp.inc"
+SuccessorOperands::SuccessorOperands(MutableOperandRange forwardedOperands)
+ : producedOperandCount(0), forwardedOperands(forwardedOperands) {}
+
+SuccessorOperands::SuccessorOperands(unsigned int producedOperandCount,
+ MutableOperandRange forwardedOperands)
+ : producedOperandCount(producedOperandCount),
+ forwardedOperands(std::move(forwardedOperands)) {}
+
//===----------------------------------------------------------------------===//
// BranchOpInterface
//===----------------------------------------------------------------------===//
@@ -26,32 +34,31 @@ using namespace mlir;
/// successor if 'operandIndex' is within the range of 'operands', or None if
/// `operandIndex` isn't a successor operand index.
Optional<BlockArgument>
-detail::getBranchSuccessorArgument(Optional<OperandRange> operands,
+detail::getBranchSuccessorArgument(const SuccessorOperands &operands,
unsigned operandIndex, Block *successor) {
+ OperandRange forwardedOperands = operands.getForwardedOperands();
// Check that the operands are valid.
- if (!operands || operands->empty())
+ if (forwardedOperands.empty())
return llvm::None;
// Check to ensure that this operand is within the range.
- unsigned operandsStart = operands->getBeginOperandIndex();
+ unsigned operandsStart = forwardedOperands.getBeginOperandIndex();
if (operandIndex < operandsStart ||
- operandIndex >= (operandsStart + operands->size()))
+ operandIndex >= (operandsStart + forwardedOperands.size()))
return llvm::None;
// Index the successor.
- unsigned argIndex = operandIndex - operandsStart;
+ unsigned argIndex =
+ operands.getProducedOperandCount() + operandIndex - operandsStart;
return successor->getArgument(argIndex);
}
/// Verify that the given operands match those of the given successor block.
LogicalResult
detail::verifyBranchSuccessorOperands(Operation *op, unsigned succNo,
- Optional<OperandRange> operands) {
- if (!operands)
- return success();
-
+ const SuccessorOperands &operands) {
// Check the count.
- unsigned operandCount = operands->size();
+ unsigned operandCount = operands.size();
Block *destBB = op->getSuccessor(succNo);
if (operandCount != destBB->getNumArguments())
return op->emitError() << "branch has " << operandCount
@@ -60,10 +67,10 @@ detail::verifyBranchSuccessorOperands(Operation *op, unsigned succNo,
<< destBB->getNumArguments();
// Check the types.
- auto operandIt = operands->begin();
- for (unsigned i = 0; i != operandCount; ++i, ++operandIt) {
+ for (unsigned i = operands.getProducedOperandCount(); i != operandCount;
+ ++i) {
if (!cast<BranchOpInterface>(op).areTypesCompatible(
- (*operandIt).getType(), destBB->getArgument(i).getType()))
+ operands[i].getType(), destBB->getArgument(i).getType()))
return op->emitError() << "type mismatch for bb argument #" << i
<< " of successor #" << succNo;
}
diff --git a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
index b1cdb3554c3fc..953fb2461520f 100644
--- a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
@@ -441,10 +441,9 @@ static Value getPHISourceValue(Block *current, Block *pred,
for (unsigned i = 0, e = terminator.getNumSuccessors(); i < e; ++i) {
Block *successor = terminator.getSuccessor(i);
auto branch = cast<BranchOpInterface>(terminator);
- Optional<OperandRange> successorOperands = branch.getSuccessorOperands(i);
+ SuccessorOperands successorOperands = branch.getSuccessorOperands(i);
assert(
- (!seenSuccessors.contains(successor) ||
- (successorOperands && successorOperands->empty())) &&
+ (!seenSuccessors.contains(successor) || successorOperands.empty()) &&
"successors with arguments in LLVM branches must be different blocks");
seenSuccessors.insert(successor);
}
diff --git a/mlir/lib/Transforms/Utils/RegionUtils.cpp b/mlir/lib/Transforms/Utils/RegionUtils.cpp
index 6ee3266f6a739..996588243f565 100644
--- a/mlir/lib/Transforms/Utils/RegionUtils.cpp
+++ b/mlir/lib/Transforms/Utils/RegionUtils.cpp
@@ -223,12 +223,14 @@ static void propagateTerminatorLiveness(Operation *op, LiveMap &liveMap) {
return;
}
- // If we can't reason about the operands to a successor, conservatively mark
- // all arguments as live.
+ // Block arguments whose values are produced by the branch operation itself
+ // can't be reasoned about, so conservatively mark them as live.
for (unsigned i = 0, e = op->getNumSuccessors(); i != e; ++i) {
- if (!branchInterface.getMutableSuccessorOperands(i))
- for (BlockArgument arg : op->getSuccessor(i)->getArguments())
- liveMap.setProvedLive(arg);
+ SuccessorOperands successorOperands =
+ branchInterface.getSuccessorOperands(i);
+ for (unsigned opI = 0, opE = successorOperands.getProducedOperandCount();
+ opI != opE; ++opI)
+ liveMap.setProvedLive(op->getSuccessor(i)->getArgument(opI));
}
}
@@ -291,18 +293,15 @@ static void eraseTerminatorSuccessorOperands(Operation *terminator,
// since it will promote later operands of the terminator being erased
// first, reducing the quadratic-ness.
unsigned succ = succE - succI - 1;
- Optional<MutableOperandRange> succOperands =
- branchOp.getMutableSuccessorOperands(succ);
- if (!succOperands)
- continue;
+ SuccessorOperands succOperands = branchOp.getSuccessorOperands(succ);
Block *successor = terminator->getSuccessor(succ);
- for (unsigned argI = 0, argE = succOperands->size(); argI < argE; ++argI) {
+ for (unsigned argI = 0, argE = succOperands.size(); argI < argE; ++argI) {
// Iterating args in reverse is needed for correctness, to avoid
// shifting later args when earlier args are erased.
unsigned arg = argE - argI - 1;
if (!liveMap.wasProvenLive(successor->getArgument(arg)))
- succOperands->erase(arg);
+ succOperands.erase(arg);
}
}
}
@@ -570,8 +569,7 @@ LogicalResult BlockMergeCluster::addToCluster(BlockEquivalenceData &blockData) {
/// their operands updated.
static bool ableToUpdatePredOperands(Block *block) {
for (auto it = block->pred_begin(), e = block->pred_end(); it != e; ++it) {
- auto branch = dyn_cast<BranchOpInterface>((*it)->getTerminator());
- if (!branch || !branch.getMutableSuccessorOperands(it.getSuccessorIndex()))
+ if (!isa<BranchOpInterface>((*it)->getTerminator()))
return false;
}
return true;
@@ -631,7 +629,7 @@ LogicalResult BlockMergeCluster::merge(RewriterBase &rewriter) {
predIt != predE; ++predIt) {
auto branch = cast<BranchOpInterface>((*predIt)->getTerminator());
unsigned succIndex = predIt.getSuccessorIndex();
- branch.getMutableSuccessorOperands(succIndex)->append(
+ branch.getSuccessorOperands(succIndex).append(
newArguments[clusterIndex]);
}
};
diff --git a/mlir/test/Transforms/sccp.mlir b/mlir/test/Transforms/sccp.mlir
index 4879ee8c54c40..a77fbe7a61a8c 100644
--- a/mlir/test/Transforms/sccp.mlir
+++ b/mlir/test/Transforms/sccp.mlir
@@ -198,3 +198,21 @@ func @recheck_executable_edge(%cond0: i1) -> (i1, i1) {
// CHECK: return %[[X]], %[[Y]]
return %x, %y : i1, i1
}
+
+// CHECK-LABEL: func @simple_produced_operand
+func @simple_produced_operand() -> (i32, i32) {
+ // CHECK: %[[ONE:.*]] = arith.constant 1
+ %1 = arith.constant 1 : i32
+ "test.internal_br"(%1) [^bb1, ^bb2] {
+ operand_segment_sizes = dense<[0, 1]> : vector<2 x i32>
+ } : (i32) -> ()
+
+^bb1:
+ cf.br ^bb2(%1, %1 : i32, i32)
+
+^bb2(%arg1 : i32, %arg2 : i32):
+ // CHECK: ^bb2(%[[ARG:.*]]: i32, %{{.*}}: i32):
+ // CHECK: return %[[ARG]], %[[ONE]] : i32, i32
+
+ return %arg1, %arg2 : i32, i32
+}
diff --git a/mlir/test/lib/Dialect/Test/TestDialect.cpp b/mlir/test/lib/Dialect/Test/TestDialect.cpp
index 73119805fcdf0..1f496ee2b09e3 100644
--- a/mlir/test/lib/Dialect/Test/TestDialect.cpp
+++ b/mlir/test/lib/Dialect/Test/TestDialect.cpp
@@ -335,22 +335,31 @@ TestDialect::getOperationPrinter(Operation *op) const {
// TestBranchOp
//===----------------------------------------------------------------------===//
-Optional<MutableOperandRange>
-TestBranchOp::getMutableSuccessorOperands(unsigned index) {
+SuccessorOperands TestBranchOp::getSuccessorOperands(unsigned index) {
assert(index == 0 && "invalid successor index");
- return getTargetOperandsMutable();
+ return SuccessorOperands(getTargetOperandsMutable());
}
//===----------------------------------------------------------------------===//
// TestProducingBranchOp
//===----------------------------------------------------------------------===//
-Optional<MutableOperandRange>
-TestProducingBranchOp::getMutableSuccessorOperands(unsigned index) {
+SuccessorOperands TestProducingBranchOp::getSuccessorOperands(unsigned index) {
assert(index <= 1 && "invalid successor index");
if (index == 1)
- return getFirstOperandsMutable();
- return getSecondOperandsMutable();
+ return SuccessorOperands(getFirstOperandsMutable());
+ return SuccessorOperands(getSecondOperandsMutable());
+}
+
+//===----------------------------------------------------------------------===//
+// TestInternalBranchOp
+//===----------------------------------------------------------------------===//
+
+SuccessorOperands TestInternalBranchOp::getSuccessorOperands(unsigned index) {
+ assert(index <= 1 && "invalid successor index");
+ if (index == 0)
+ return SuccessorOperands(0, getSuccessOperandsMutable());
+ return SuccessorOperands(1, getErrorOperandsMutable());
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
index 9902b57323ca3..bccca927725e0 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -642,6 +642,17 @@ def TestProducingBranchOp : TEST_Op<"producing_br",
let successors = (successor AnySuccessor:$first,AnySuccessor:$second);
}
+// Produces an error value as the first block argument of the error successor.
+def TestInternalBranchOp : TEST_Op<"internal_br",
+ [DeclareOpInterfaceMethods<BranchOpInterface>, Terminator,
+ AttrSizedOperandSegments]> {
+
+ let arguments = (ins Variadic<AnyType>:$successOperands,
+ Variadic<AnyType>:$errorOperands);
+
+ let successors = (successor AnySuccessor:$successPath, AnySuccessor:$errorPath);
+}
+
def AttrSizedOperandOp : TEST_Op<"attr_sized_operands",
[AttrSizedOperandSegments]> {
let arguments = (ins
More information about the Mlir-commits
mailing list