[Mlir-commits] [mlir] [mlir][SCF] `ForOp`: Remove `getIterArgNumberForOpOperand` (PR #66629)
Matthias Springer
llvmlistbot at llvm.org
Mon Sep 18 03:36:21 PDT 2023
https://github.com/matthias-springer created https://github.com/llvm/llvm-project/pull/66629
This function was inconsistent with the rest of the API because it accepted an `OpOperand &` that does not belong to the op (returning `std::nullopt` in that case), whereas all other functions assert. The helper is also unnecessary: the iter_arg number is identical to the corresponding result number, which can be obtained via `getResultForOpOperand(opOperand).getResultNumber()`.
Depends on #66622. Review only the top commit.
>From 5046dd723deb659288cbfef0242f1517d0e765ba Mon Sep 17 00:00:00 2001
From: Matthias Springer <me at m-sp.org>
Date: Mon, 18 Sep 2023 11:52:06 +0200
Subject: [PATCH 1/2] [mlir][IR] Change `MutableOperandRange` to enumerate
`OpOperand &`
In line with #66515, change `MutableOperandRange::begin`/`end` to enumerate `OpOperand &` instead of `Value`. Also remove `ForOp::getIterOpOperands`/`setIterArg`, which are now redundant (iterate over `getInitArgsMutable()` and call `OpOperand::set` instead).
Note: `MutableOperandRange` cannot be made a derived class of `indexed_accessor_range_base` (like `OperandRange`), because `MutableOperandRange::assign` can change the number of operands in the range.
---
mlir/include/mlir/Dialect/SCF/IR/SCFOps.td | 7 -----
mlir/include/mlir/IR/ValueRange.h | 10 ++-----
.../OwnershipBasedBufferDeallocation.cpp | 6 +++-
.../BufferDeallocationOpInterfaceImpl.cpp | 10 +++----
mlir/lib/Dialect/SCF/IR/SCF.cpp | 4 +--
.../BufferizableOpInterfaceImpl.cpp | 6 ++--
.../SCF/Transforms/TileUsingInterface.cpp | 16 +++++------
mlir/lib/IR/OperationSupport.cpp | 8 ++++++
mlir/lib/Transforms/Utils/CFGToSCF.cpp | 28 +++++++++++++------
9 files changed, 53 insertions(+), 42 deletions(-)
diff --git a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
index 6d8aaf64e3263b9..8a9ce949a750d43 100644
--- a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
+++ b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
@@ -250,17 +250,10 @@ def ForOp : SCF_Op<"for",
"expected an index less than the number of region iter args");
return getBody()->getArguments().drop_front(getNumInductionVars())[index];
}
- MutableArrayRef<OpOperand> getIterOpOperands() {
- return
- getOperation()->getOpOperands().drop_front(getNumControlOperands());
- }
void setLowerBound(Value bound) { getOperation()->setOperand(0, bound); }
void setUpperBound(Value bound) { getOperation()->setOperand(1, bound); }
void setStep(Value step) { getOperation()->setOperand(2, step); }
- void setIterArg(unsigned iterArgNum, Value iterArgValue) {
- getOperation()->setOperand(iterArgNum + getNumControlOperands(), iterArgValue);
- }
/// Number of induction variables, always 1 for scf::ForOp.
unsigned getNumInductionVars() { return 1; }
diff --git a/mlir/include/mlir/IR/ValueRange.h b/mlir/include/mlir/IR/ValueRange.h
index f1a1f1841f179e7..9c11178f9cd9cae 100644
--- a/mlir/include/mlir/IR/ValueRange.h
+++ b/mlir/include/mlir/IR/ValueRange.h
@@ -165,13 +165,9 @@ class MutableOperandRange {
/// Returns the OpOperand at the given index.
OpOperand &operator[](unsigned index) const;
- OperandRange::iterator begin() const {
- return static_cast<OperandRange>(*this).begin();
- }
-
- OperandRange::iterator end() const {
- return static_cast<OperandRange>(*this).end();
- }
+ /// Iterators enumerate OpOperands.
+ MutableArrayRef<OpOperand>::iterator begin() const;
+ MutableArrayRef<OpOperand>::iterator end() const;
private:
/// Update the length of this range to the one provided.
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation.cpp b/mlir/lib/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation.cpp
index 43ba11cf132cb92..09d30835828084d 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation.cpp
@@ -47,6 +47,10 @@ static Value buildBoolValue(OpBuilder &builder, Location loc, bool value) {
static bool isMemref(Value v) { return v.getType().isa<BaseMemRefType>(); }
+static bool isMemrefOperand(OpOperand &operand) {
+ return isMemref(operand.get());
+}
+
//===----------------------------------------------------------------------===//
// Backedges analysis
//===----------------------------------------------------------------------===//
@@ -937,7 +941,7 @@ BufferDeallocation::handleInterface(RegionBranchTerminatorOpInterface op) {
// Add an additional operand for every MemRef for the ownership indicator.
if (!funcWithoutDynamicOwnership) {
- unsigned numMemRefs = llvm::count_if(operands, isMemref);
+ unsigned numMemRefs = llvm::count_if(operands, isMemrefOperand);
SmallVector<Value> newOperands{OperandRange(operands)};
auto ownershipValues =
deallocOp.getUpdatedConditions().take_front(numMemRefs);
diff --git a/mlir/lib/Dialect/ControlFlow/Transforms/BufferDeallocationOpInterfaceImpl.cpp b/mlir/lib/Dialect/ControlFlow/Transforms/BufferDeallocationOpInterfaceImpl.cpp
index e847e946eef1b5d..9423af2542690d9 100644
--- a/mlir/lib/Dialect/ControlFlow/Transforms/BufferDeallocationOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/ControlFlow/Transforms/BufferDeallocationOpInterfaceImpl.cpp
@@ -96,12 +96,12 @@ struct CondBranchOpInterface
mapping[retained] = ownership;
}
SmallVector<Value> replacements, ownerships;
- for (Value operand : destOperands) {
- replacements.push_back(operand);
- if (isMemref(operand)) {
- assert(mapping.contains(operand) &&
+ for (OpOperand &operand : destOperands) {
+ replacements.push_back(operand.get());
+ if (isMemref(operand.get())) {
+ assert(mapping.contains(operand.get()) &&
"Should be contained at this point");
- ownerships.push_back(mapping[operand]);
+ ownerships.push_back(mapping[operand.get()]);
}
}
replacements.append(ownerships);
diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
index 5565aefbad18db5..2a760c76d2f6867 100644
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -932,7 +932,7 @@ replaceTensorCastForOpIterArg(PatternRewriter &rewriter, OpOperand &operand,
assert(operand.get().getType() != replacement.getType() &&
"Expected a different type");
SmallVector<Value> newIterOperands;
- for (OpOperand &opOperand : forOp.getIterOpOperands()) {
+ for (OpOperand &opOperand : forOp.getInitArgsMutable()) {
if (opOperand.getOperandNumber() == operand.getOperandNumber()) {
newIterOperands.push_back(replacement);
continue;
@@ -1015,7 +1015,7 @@ struct ForOpTensorCastFolder : public OpRewritePattern<ForOp> {
LogicalResult matchAndRewrite(ForOp op,
PatternRewriter &rewriter) const override {
- for (auto it : llvm::zip(op.getIterOpOperands(), op.getResults())) {
+ for (auto it : llvm::zip(op.getInitArgsMutable(), op.getResults())) {
OpOperand &iterOpOperand = std::get<0>(it);
auto incomingCast = iterOpOperand.get().getDefiningOp<tensor::CastOp>();
if (!incomingCast ||
diff --git a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
index 11cfefed890c669..8c04a8887013c8f 100644
--- a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -332,7 +332,7 @@ DenseSet<int64_t> getEquivalentBuffers(Block::BlockArgListType bbArgs,
/// Helper function for loop bufferization. Return the bufferized values of the
/// given OpOperands. If an operand is not a tensor, return the original value.
static FailureOr<SmallVector<Value>>
-getBuffers(RewriterBase &rewriter, MutableArrayRef<OpOperand> operands,
+getBuffers(RewriterBase &rewriter, MutableOperandRange operands,
const BufferizationOptions &options) {
SmallVector<Value> result;
for (OpOperand &opOperand : operands) {
@@ -606,7 +606,7 @@ struct ForOpInterface
// The new memref init_args of the loop.
FailureOr<SmallVector<Value>> maybeInitArgs =
- getBuffers(rewriter, forOp.getIterOpOperands(), options);
+ getBuffers(rewriter, forOp.getInitArgsMutable(), options);
if (failed(maybeInitArgs))
return failure();
SmallVector<Value> initArgs = *maybeInitArgs;
@@ -825,7 +825,7 @@ struct WhileOpInterface
// The new memref init_args of the loop.
FailureOr<SmallVector<Value>> maybeInitArgs =
- getBuffers(rewriter, whileOp->getOpOperands(), options);
+ getBuffers(rewriter, whileOp.getInitsMutable(), options);
if (failed(maybeInitArgs))
return failure();
SmallVector<Value> initArgs = *maybeInitArgs;
diff --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
index 1ce25565edcaf61..ceec0756e421ffd 100644
--- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
@@ -508,7 +508,7 @@ mlir::scf::tileAndFuseProducerOfSlice(RewriterBase &rewriter,
MutableArrayRef<scf::ForOp> loops) {
// 1. Get the producer of the source (potentially walking through
// `iter_args` of nested `scf.for`)
- auto [fusableProducer, destinationIterArg] =
+ auto [fusableProducer, destinationInitArg] =
getUntiledProducerFromSliceSource(&candidateSliceOp.getSourceMutable()[0],
loops);
if (!fusableProducer)
@@ -575,17 +575,15 @@ mlir::scf::tileAndFuseProducerOfSlice(RewriterBase &rewriter,
// TODO: This can be modeled better if the `DestinationStyleOpInterface`.
// Update to use that when it does become available.
scf::ForOp outerMostLoop = loops.front();
- std::optional<unsigned> iterArgNumber;
- if (destinationIterArg) {
- iterArgNumber =
- outerMostLoop.getIterArgNumberForOpOperand(*destinationIterArg.value());
- }
- if (iterArgNumber) {
+ if (destinationInitArg &&
+ (*destinationInitArg)->getOwner() == outerMostLoop) {
+ std::optional<unsigned> iterArgNumber =
+ outerMostLoop.getIterArgNumberForOpOperand(**destinationInitArg);
int64_t resultNumber = fusableProducer.getResultNumber();
if (auto dstOp =
dyn_cast<DestinationStyleOpInterface>(fusableProducer.getOwner())) {
- outerMostLoop.setIterArg(iterArgNumber.value(),
- dstOp.getTiedOpOperand(fusableProducer)->get());
+ (*destinationInitArg)
+ ->set(dstOp.getTiedOpOperand(fusableProducer)->get());
}
for (auto tileAndFusedOp : tileAndFuseResult->tiledOps) {
auto dstOp = dyn_cast<DestinationStyleOpInterface>(tileAndFusedOp);
diff --git a/mlir/lib/IR/OperationSupport.cpp b/mlir/lib/IR/OperationSupport.cpp
index 7b17e231ce1065f..b0c50f3d6e29855 100644
--- a/mlir/lib/IR/OperationSupport.cpp
+++ b/mlir/lib/IR/OperationSupport.cpp
@@ -522,6 +522,14 @@ OpOperand &MutableOperandRange::operator[](unsigned index) const {
return owner->getOpOperand(start + index);
}
+MutableArrayRef<OpOperand>::iterator MutableOperandRange::begin() const {
+ return owner->getOpOperands().slice(start, length).begin();
+}
+
+MutableArrayRef<OpOperand>::iterator MutableOperandRange::end() const {
+ return owner->getOpOperands().slice(start, length).end();
+}
+
//===----------------------------------------------------------------------===//
// MutableOperandRangeRange
diff --git a/mlir/lib/Transforms/Utils/CFGToSCF.cpp b/mlir/lib/Transforms/Utils/CFGToSCF.cpp
index 9aab89ed7553600..e7bf6628ccbd7e0 100644
--- a/mlir/lib/Transforms/Utils/CFGToSCF.cpp
+++ b/mlir/lib/Transforms/Utils/CFGToSCF.cpp
@@ -137,6 +137,13 @@ getMutableSuccessorOperands(Block *block, unsigned successorIndex) {
return succOps.getMutableForwardedOperands();
}
+/// Return the operand range used to transfer operands from `block` to its
+/// successor with the given index.
+static OperandRange getSuccessorOperands(Block *block,
+ unsigned successorIndex) {
+ return getMutableSuccessorOperands(block, successorIndex);
+}
+
/// Appends all the block arguments from `other` to the block arguments of
/// `block`, copying their types and locations.
static void addBlockArgumentsFromOther(Block *block, Block *other) {
@@ -175,8 +182,14 @@ class Edge {
/// Returns the arguments of this edge that are passed to the block arguments
/// of the successor.
- MutableOperandRange getSuccessorOperands() const {
- return getMutableSuccessorOperands(fromBlock, successorIndex);
+ MutableOperandRange getMutableSuccessorOperands() const {
+ return ::getMutableSuccessorOperands(fromBlock, successorIndex);
+ }
+
+ /// Returns the arguments of this edge that are passed to the block arguments
+ /// of the successor.
+ OperandRange getSuccessorOperands() const {
+ return ::getSuccessorOperands(fromBlock, successorIndex);
}
};
@@ -262,7 +275,7 @@ class EdgeMultiplexer {
assert(result != blockArgMapping.end() &&
"Edge was not originally passed to `create` method.");
- MutableOperandRange successorOperands = edge.getSuccessorOperands();
+ MutableOperandRange successorOperands = edge.getMutableSuccessorOperands();
// Extra arguments are always appended at the end of the block arguments.
unsigned extraArgsBeginIndex =
@@ -666,7 +679,7 @@ transformToReduceLoop(Block *loopHeader, Block *exitBlock,
// invalidated when mutating the operands through a different
// `MutableOperandRange` of the same operation.
SmallVector<Value> loopHeaderSuccessorOperands =
- llvm::to_vector(getMutableSuccessorOperands(latch, loopHeaderIndex));
+ llvm::to_vector(getSuccessorOperands(latch, loopHeaderIndex));
// Add all values used in the next iteration to the exit block. Replace
// any uses that are outside the loop with the newly created exit block.
@@ -742,7 +755,7 @@ transformToReduceLoop(Block *loopHeader, Block *exitBlock,
loopHeaderSuccessorOperands.push_back(argument);
for (Edge edge : successorEdges(latch))
- edge.getSuccessorOperands().append(argument);
+ edge.getMutableSuccessorOperands().append(argument);
}
use.set(blockArgument);
@@ -939,9 +952,8 @@ static FailureOr<SmallVector<Block *>> transformToStructuredCFBranches(
if (regionEntry->getNumSuccessors() == 1) {
// Single successor we can just splice together.
Block *successor = regionEntry->getSuccessor(0);
- for (auto &&[oldValue, newValue] :
- llvm::zip(successor->getArguments(),
- getMutableSuccessorOperands(regionEntry, 0)))
+ for (auto &&[oldValue, newValue] : llvm::zip(
+ successor->getArguments(), getSuccessorOperands(regionEntry, 0)))
oldValue.replaceAllUsesWith(newValue);
regionEntry->getTerminator()->erase();
>From 148bb36558497e81c3d09c132d21309f18dd1acc Mon Sep 17 00:00:00 2001
From: Matthias Springer <me at m-sp.org>
Date: Mon, 18 Sep 2023 12:34:21 +0200
Subject: [PATCH 2/2] [mlir][SCF] `ForOp`: Remove
`getIterArgNumberForOpOperand`
This function was inconsistent with the rest of the API because it accepted an `OpOperand &` that does not belong to the op (returning `std::nullopt` in that case), whereas all other functions assert. The helper is also unnecessary: the iter_arg number is identical to the corresponding result number, which can be obtained via `getResultForOpOperand(opOperand).getResultNumber()`.
---
mlir/include/mlir/Dialect/SCF/IR/SCFOps.td | 10 ---------
.../Linalg/Transforms/HoistPadding.cpp | 22 ++++++++-----------
.../SCF/Transforms/TileUsingInterface.cpp | 7 +++---
3 files changed, 13 insertions(+), 26 deletions(-)
diff --git a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
index 8a9ce949a750d43..89c1a06412947b2 100644
--- a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
+++ b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
@@ -263,16 +263,6 @@ def ForOp : SCF_Op<"for",
}
/// Number of operands controlling the loop: lb, ub, step
unsigned getNumControlOperands() { return 3; }
- /// Get the iter arg number for an operand. If it isnt an iter arg
- /// operand return std::nullopt.
- std::optional<unsigned> getIterArgNumberForOpOperand(OpOperand &opOperand) {
- if (opOperand.getOwner() != getOperation())
- return std::nullopt;
- unsigned operandNumber = opOperand.getOperandNumber();
- if (operandNumber < getNumControlOperands())
- return std::nullopt;
- return operandNumber - getNumControlOperands();
- }
/// Get the region iter arg that corresponds to an OpOperand.
/// This helper prevents internal op implementation detail leakage to
diff --git a/mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp b/mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp
index 21bc0554e717692..a9debb7bbc489a4 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp
@@ -810,13 +810,9 @@ padThroughLoopIterArg(RewriterBase &rewriter, Value paddedValueBeforeHoisting,
OpBuilder::InsertionGuard g(rewriter);
rewriter.setInsertionPointAfter(hoistedPackedTensor.getDefiningOp());
- std::optional<unsigned> maybeOperandNumber =
- forOp.getIterArgNumberForOpOperand(*pUse);
- assert(maybeOperandNumber.has_value() && "expected a proper iter arg number");
-
- int64_t operandNumber = maybeOperandNumber.value();
+ unsigned iterArgNumber = forOp.getResultForOpOperand(*pUse).getResultNumber();
auto yieldOp = cast<scf::YieldOp>(forOp.getBody(0)->getTerminator());
- auto yieldingExtractSliceOp = yieldOp->getOperand(operandNumber)
+ auto yieldingExtractSliceOp = yieldOp->getOperand(iterArgNumber)
.getDefiningOp<tensor::ExtractSliceOp>();
if (!yieldingExtractSliceOp)
return tensor::ExtractSliceOp();
@@ -829,9 +825,9 @@ padThroughLoopIterArg(RewriterBase &rewriter, Value paddedValueBeforeHoisting,
return tensor::ExtractSliceOp();
SmallVector<Value> initArgs = forOp.getInitArgs();
- initArgs[operandNumber] = hoistedPackedTensor;
+ initArgs[iterArgNumber] = hoistedPackedTensor;
SmallVector<Value> yieldOperands = yieldOp.getOperands();
- yieldOperands[operandNumber] = yieldingExtractSliceOp.getSource();
+ yieldOperands[iterArgNumber] = yieldingExtractSliceOp.getSource();
int64_t numOriginalForOpResults = initArgs.size();
LLVM_DEBUG(DBGS() << "numOriginalForOpResults: " << numOriginalForOpResults
@@ -844,7 +840,7 @@ padThroughLoopIterArg(RewriterBase &rewriter, Value paddedValueBeforeHoisting,
hoistedPackedTensor.getLoc(), hoistedPackedTensor,
outerSliceOp.getMixedOffsets(), outerSliceOp.getMixedSizes(),
outerSliceOp.getMixedStrides());
- rewriter.replaceAllUsesWith(forOp.getResult(operandNumber), extracted);
+ rewriter.replaceAllUsesWith(forOp.getResult(iterArgNumber), extracted);
}
scf::ForOp newForOp =
replaceLoopWithNewYields(rewriter, forOp, initArgs, yieldOperands);
@@ -853,20 +849,20 @@ padThroughLoopIterArg(RewriterBase &rewriter, Value paddedValueBeforeHoisting,
<< "\n");
LLVM_DEBUG(DBGS() << "replace source of: " << extracted << "\n");
LLVM_DEBUG(DBGS() << "with result #"
- << numOriginalForOpResults + operandNumber
+ << numOriginalForOpResults + iterArgNumber
<< " of forOp, giving us: " << extracted << "\n");
rewriter.startRootUpdate(extracted);
extracted.getSourceMutable().assign(
- newForOp.getResult(numOriginalForOpResults + operandNumber));
+ newForOp.getResult(numOriginalForOpResults + iterArgNumber));
rewriter.finalizeRootUpdate(extracted);
LLVM_DEBUG(DBGS() << "replace uses of: " << paddedValueBeforeHoisting
<< "\n");
LLVM_DEBUG(DBGS() << "with region iter arg #"
- << numOriginalForOpResults + operandNumber << "\n");
+ << numOriginalForOpResults + iterArgNumber << "\n");
rewriter.replaceAllUsesWith(
paddedValueBeforeHoisting,
- newForOp.getRegionIterArg(numOriginalForOpResults + operandNumber));
+ newForOp.getRegionIterArg(numOriginalForOpResults + iterArgNumber));
return extracted;
}
diff --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
index ceec0756e421ffd..d7b986e052ae575 100644
--- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
@@ -577,8 +577,9 @@ mlir::scf::tileAndFuseProducerOfSlice(RewriterBase &rewriter,
scf::ForOp outerMostLoop = loops.front();
if (destinationInitArg &&
(*destinationInitArg)->getOwner() == outerMostLoop) {
- std::optional<unsigned> iterArgNumber =
- outerMostLoop.getIterArgNumberForOpOperand(**destinationInitArg);
+ unsigned iterArgNumber =
+ outerMostLoop.getResultForOpOperand(**destinationInitArg)
+ .getResultNumber();
int64_t resultNumber = fusableProducer.getResultNumber();
if (auto dstOp =
dyn_cast<DestinationStyleOpInterface>(fusableProducer.getOwner())) {
@@ -592,7 +593,7 @@ mlir::scf::tileAndFuseProducerOfSlice(RewriterBase &rewriter,
scf::ForOp innerMostLoop = loops.back();
updateDestinationOperandsForTiledOp(
rewriter, dstOp.getDpsInitOperand(resultNumber)->get(),
- innerMostLoop.getRegionIterArgs()[iterArgNumber.value()]);
+ innerMostLoop.getRegionIterArgs()[iterArgNumber]);
}
}
return scf::SCFFuseProducerOfSliceResult{fusableProducer,
More information about the Mlir-commits
mailing list