[Mlir-commits] [mlir] [mlir][IR] Change `MutableOperandRange` to enumerate `OpOperand &` (PR #66622)
Matthias Springer
llvmlistbot at llvm.org
Mon Sep 18 02:55:55 PDT 2023
https://github.com/matthias-springer created https://github.com/llvm/llvm-project/pull/66622
In line with #66515, change `MutableOperandRange::begin`/`end` to enumerate `OpOperand &` instead of `Value`. Also remove `ForOp::getIterOpOperands`/`setIterArg`, which are now redundant.
Note: `MutableOperandRange` cannot be made a derived class of `indexed_accessor_range_base` (like `OperandRange`), because `MutableOperandRange::assign` can change the number of operands in the range, whereas `indexed_accessor_range_base` stores an element count that is fixed at construction.
>From 5046dd723deb659288cbfef0242f1517d0e765ba Mon Sep 17 00:00:00 2001
From: Matthias Springer <me at m-sp.org>
Date: Mon, 18 Sep 2023 11:52:06 +0200
Subject: [PATCH] [mlir][IR] Change `MutableOperandRange` to enumerate `OpOperand
&`
In line with #66515, change `MutableOperandRange::begin`/`end` to enumerate `OpOperand &` instead of `Value`. Also remove `ForOp::getIterOpOperands`/`setIterArg`, which are now redundant.
Note: `MutableOperandRange` cannot be made a derived class of `indexed_accessor_range_base` (like `OperandRange`), because `MutableOperandRange::assign` can change the number of operands in the range.
---
mlir/include/mlir/Dialect/SCF/IR/SCFOps.td | 7 -----
mlir/include/mlir/IR/ValueRange.h | 10 ++-----
.../OwnershipBasedBufferDeallocation.cpp | 6 +++-
.../BufferDeallocationOpInterfaceImpl.cpp | 10 +++----
mlir/lib/Dialect/SCF/IR/SCF.cpp | 4 +--
.../BufferizableOpInterfaceImpl.cpp | 6 ++--
.../SCF/Transforms/TileUsingInterface.cpp | 16 +++++------
mlir/lib/IR/OperationSupport.cpp | 8 ++++++
mlir/lib/Transforms/Utils/CFGToSCF.cpp | 28 +++++++++++++------
9 files changed, 53 insertions(+), 42 deletions(-)
diff --git a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
index 6d8aaf64e3263b9..8a9ce949a750d43 100644
--- a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
+++ b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
@@ -250,17 +250,10 @@ def ForOp : SCF_Op<"for",
"expected an index less than the number of region iter args");
return getBody()->getArguments().drop_front(getNumInductionVars())[index];
}
- MutableArrayRef<OpOperand> getIterOpOperands() {
- return
- getOperation()->getOpOperands().drop_front(getNumControlOperands());
- }
void setLowerBound(Value bound) { getOperation()->setOperand(0, bound); }
void setUpperBound(Value bound) { getOperation()->setOperand(1, bound); }
void setStep(Value step) { getOperation()->setOperand(2, step); }
- void setIterArg(unsigned iterArgNum, Value iterArgValue) {
- getOperation()->setOperand(iterArgNum + getNumControlOperands(), iterArgValue);
- }
/// Number of induction variables, always 1 for scf::ForOp.
unsigned getNumInductionVars() { return 1; }
diff --git a/mlir/include/mlir/IR/ValueRange.h b/mlir/include/mlir/IR/ValueRange.h
index f1a1f1841f179e7..9c11178f9cd9cae 100644
--- a/mlir/include/mlir/IR/ValueRange.h
+++ b/mlir/include/mlir/IR/ValueRange.h
@@ -165,13 +165,9 @@ class MutableOperandRange {
/// Returns the OpOperand at the given index.
OpOperand &operator[](unsigned index) const;
- OperandRange::iterator begin() const {
- return static_cast<OperandRange>(*this).begin();
- }
-
- OperandRange::iterator end() const {
- return static_cast<OperandRange>(*this).end();
- }
+ /// Iterators enumerate OpOperands.
+ MutableArrayRef<OpOperand>::iterator begin() const;
+ MutableArrayRef<OpOperand>::iterator end() const;
private:
/// Update the length of this range to the one provided.
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation.cpp b/mlir/lib/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation.cpp
index 43ba11cf132cb92..09d30835828084d 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation.cpp
@@ -47,6 +47,10 @@ static Value buildBoolValue(OpBuilder &builder, Location loc, bool value) {
static bool isMemref(Value v) { return v.getType().isa<BaseMemRefType>(); }
+static bool isMemrefOperand(OpOperand &operand) {
+ return isMemref(operand.get());
+}
+
//===----------------------------------------------------------------------===//
// Backedges analysis
//===----------------------------------------------------------------------===//
@@ -937,7 +941,7 @@ BufferDeallocation::handleInterface(RegionBranchTerminatorOpInterface op) {
// Add an additional operand for every MemRef for the ownership indicator.
if (!funcWithoutDynamicOwnership) {
- unsigned numMemRefs = llvm::count_if(operands, isMemref);
+ unsigned numMemRefs = llvm::count_if(operands, isMemrefOperand);
SmallVector<Value> newOperands{OperandRange(operands)};
auto ownershipValues =
deallocOp.getUpdatedConditions().take_front(numMemRefs);
diff --git a/mlir/lib/Dialect/ControlFlow/Transforms/BufferDeallocationOpInterfaceImpl.cpp b/mlir/lib/Dialect/ControlFlow/Transforms/BufferDeallocationOpInterfaceImpl.cpp
index e847e946eef1b5d..9423af2542690d9 100644
--- a/mlir/lib/Dialect/ControlFlow/Transforms/BufferDeallocationOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/ControlFlow/Transforms/BufferDeallocationOpInterfaceImpl.cpp
@@ -96,12 +96,12 @@ struct CondBranchOpInterface
mapping[retained] = ownership;
}
SmallVector<Value> replacements, ownerships;
- for (Value operand : destOperands) {
- replacements.push_back(operand);
- if (isMemref(operand)) {
- assert(mapping.contains(operand) &&
+ for (OpOperand &operand : destOperands) {
+ replacements.push_back(operand.get());
+ if (isMemref(operand.get())) {
+ assert(mapping.contains(operand.get()) &&
"Should be contained at this point");
- ownerships.push_back(mapping[operand]);
+ ownerships.push_back(mapping[operand.get()]);
}
}
replacements.append(ownerships);
diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
index 5565aefbad18db5..2a760c76d2f6867 100644
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -932,7 +932,7 @@ replaceTensorCastForOpIterArg(PatternRewriter &rewriter, OpOperand &operand,
assert(operand.get().getType() != replacement.getType() &&
"Expected a different type");
SmallVector<Value> newIterOperands;
- for (OpOperand &opOperand : forOp.getIterOpOperands()) {
+ for (OpOperand &opOperand : forOp.getInitArgsMutable()) {
if (opOperand.getOperandNumber() == operand.getOperandNumber()) {
newIterOperands.push_back(replacement);
continue;
@@ -1015,7 +1015,7 @@ struct ForOpTensorCastFolder : public OpRewritePattern<ForOp> {
LogicalResult matchAndRewrite(ForOp op,
PatternRewriter &rewriter) const override {
- for (auto it : llvm::zip(op.getIterOpOperands(), op.getResults())) {
+ for (auto it : llvm::zip(op.getInitArgsMutable(), op.getResults())) {
OpOperand &iterOpOperand = std::get<0>(it);
auto incomingCast = iterOpOperand.get().getDefiningOp<tensor::CastOp>();
if (!incomingCast ||
diff --git a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
index 11cfefed890c669..8c04a8887013c8f 100644
--- a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -332,7 +332,7 @@ DenseSet<int64_t> getEquivalentBuffers(Block::BlockArgListType bbArgs,
/// Helper function for loop bufferization. Return the bufferized values of the
/// given OpOperands. If an operand is not a tensor, return the original value.
static FailureOr<SmallVector<Value>>
-getBuffers(RewriterBase &rewriter, MutableArrayRef<OpOperand> operands,
+getBuffers(RewriterBase &rewriter, MutableOperandRange operands,
const BufferizationOptions &options) {
SmallVector<Value> result;
for (OpOperand &opOperand : operands) {
@@ -606,7 +606,7 @@ struct ForOpInterface
// The new memref init_args of the loop.
FailureOr<SmallVector<Value>> maybeInitArgs =
- getBuffers(rewriter, forOp.getIterOpOperands(), options);
+ getBuffers(rewriter, forOp.getInitArgsMutable(), options);
if (failed(maybeInitArgs))
return failure();
SmallVector<Value> initArgs = *maybeInitArgs;
@@ -825,7 +825,7 @@ struct WhileOpInterface
// The new memref init_args of the loop.
FailureOr<SmallVector<Value>> maybeInitArgs =
- getBuffers(rewriter, whileOp->getOpOperands(), options);
+ getBuffers(rewriter, whileOp.getInitsMutable(), options);
if (failed(maybeInitArgs))
return failure();
SmallVector<Value> initArgs = *maybeInitArgs;
diff --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
index 1ce25565edcaf61..ceec0756e421ffd 100644
--- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
@@ -508,7 +508,7 @@ mlir::scf::tileAndFuseProducerOfSlice(RewriterBase &rewriter,
MutableArrayRef<scf::ForOp> loops) {
// 1. Get the producer of the source (potentially walking through
// `iter_args` of nested `scf.for`)
- auto [fusableProducer, destinationIterArg] =
+ auto [fusableProducer, destinationInitArg] =
getUntiledProducerFromSliceSource(&candidateSliceOp.getSourceMutable()[0],
loops);
if (!fusableProducer)
@@ -575,17 +575,15 @@ mlir::scf::tileAndFuseProducerOfSlice(RewriterBase &rewriter,
// TODO: This can be modeled better if the `DestinationStyleOpInterface`.
// Update to use that when it does become available.
scf::ForOp outerMostLoop = loops.front();
- std::optional<unsigned> iterArgNumber;
- if (destinationIterArg) {
- iterArgNumber =
- outerMostLoop.getIterArgNumberForOpOperand(*destinationIterArg.value());
- }
- if (iterArgNumber) {
+ if (destinationInitArg &&
+ (*destinationInitArg)->getOwner() == outerMostLoop) {
+ std::optional<unsigned> iterArgNumber =
+ outerMostLoop.getIterArgNumberForOpOperand(**destinationInitArg);
int64_t resultNumber = fusableProducer.getResultNumber();
if (auto dstOp =
dyn_cast<DestinationStyleOpInterface>(fusableProducer.getOwner())) {
- outerMostLoop.setIterArg(iterArgNumber.value(),
- dstOp.getTiedOpOperand(fusableProducer)->get());
+ (*destinationInitArg)
+ ->set(dstOp.getTiedOpOperand(fusableProducer)->get());
}
for (auto tileAndFusedOp : tileAndFuseResult->tiledOps) {
auto dstOp = dyn_cast<DestinationStyleOpInterface>(tileAndFusedOp);
diff --git a/mlir/lib/IR/OperationSupport.cpp b/mlir/lib/IR/OperationSupport.cpp
index 7b17e231ce1065f..b0c50f3d6e29855 100644
--- a/mlir/lib/IR/OperationSupport.cpp
+++ b/mlir/lib/IR/OperationSupport.cpp
@@ -522,6 +522,14 @@ OpOperand &MutableOperandRange::operator[](unsigned index) const {
return owner->getOpOperand(start + index);
}
+MutableArrayRef<OpOperand>::iterator MutableOperandRange::begin() const {
+ return owner->getOpOperands().slice(start, length).begin();
+}
+
+MutableArrayRef<OpOperand>::iterator MutableOperandRange::end() const {
+ return owner->getOpOperands().slice(start, length).end();
+}
+
//===----------------------------------------------------------------------===//
// MutableOperandRangeRange
diff --git a/mlir/lib/Transforms/Utils/CFGToSCF.cpp b/mlir/lib/Transforms/Utils/CFGToSCF.cpp
index 9aab89ed7553600..e7bf6628ccbd7e0 100644
--- a/mlir/lib/Transforms/Utils/CFGToSCF.cpp
+++ b/mlir/lib/Transforms/Utils/CFGToSCF.cpp
@@ -137,6 +137,13 @@ getMutableSuccessorOperands(Block *block, unsigned successorIndex) {
return succOps.getMutableForwardedOperands();
}
+/// Return the operand range used to transfer operands from `block` to its
+/// successor with the given index.
+static OperandRange getSuccessorOperands(Block *block,
+ unsigned successorIndex) {
+ return getMutableSuccessorOperands(block, successorIndex);
+}
+
/// Appends all the block arguments from `other` to the block arguments of
/// `block`, copying their types and locations.
static void addBlockArgumentsFromOther(Block *block, Block *other) {
@@ -175,8 +182,14 @@ class Edge {
/// Returns the arguments of this edge that are passed to the block arguments
/// of the successor.
- MutableOperandRange getSuccessorOperands() const {
- return getMutableSuccessorOperands(fromBlock, successorIndex);
+ MutableOperandRange getMutableSuccessorOperands() const {
+ return ::getMutableSuccessorOperands(fromBlock, successorIndex);
+ }
+
+ /// Returns the arguments of this edge that are passed to the block arguments
+ /// of the successor.
+ OperandRange getSuccessorOperands() const {
+ return ::getSuccessorOperands(fromBlock, successorIndex);
}
};
@@ -262,7 +275,7 @@ class EdgeMultiplexer {
assert(result != blockArgMapping.end() &&
"Edge was not originally passed to `create` method.");
- MutableOperandRange successorOperands = edge.getSuccessorOperands();
+ MutableOperandRange successorOperands = edge.getMutableSuccessorOperands();
// Extra arguments are always appended at the end of the block arguments.
unsigned extraArgsBeginIndex =
@@ -666,7 +679,7 @@ transformToReduceLoop(Block *loopHeader, Block *exitBlock,
// invalidated when mutating the operands through a different
// `MutableOperandRange` of the same operation.
SmallVector<Value> loopHeaderSuccessorOperands =
- llvm::to_vector(getMutableSuccessorOperands(latch, loopHeaderIndex));
+ llvm::to_vector(getSuccessorOperands(latch, loopHeaderIndex));
// Add all values used in the next iteration to the exit block. Replace
// any uses that are outside the loop with the newly created exit block.
@@ -742,7 +755,7 @@ transformToReduceLoop(Block *loopHeader, Block *exitBlock,
loopHeaderSuccessorOperands.push_back(argument);
for (Edge edge : successorEdges(latch))
- edge.getSuccessorOperands().append(argument);
+ edge.getMutableSuccessorOperands().append(argument);
}
use.set(blockArgument);
@@ -939,9 +952,8 @@ static FailureOr<SmallVector<Block *>> transformToStructuredCFBranches(
if (regionEntry->getNumSuccessors() == 1) {
// Single successor we can just splice together.
Block *successor = regionEntry->getSuccessor(0);
- for (auto &&[oldValue, newValue] :
- llvm::zip(successor->getArguments(),
- getMutableSuccessorOperands(regionEntry, 0)))
+ for (auto &&[oldValue, newValue] : llvm::zip(
+ successor->getArguments(), getSuccessorOperands(regionEntry, 0)))
oldValue.replaceAllUsesWith(newValue);
regionEntry->getTerminator()->erase();
More information about the Mlir-commits
mailing list