[Mlir-commits] [mlir] [mlir][SCF] scf.for: Consistent API around `initArgs` (PR #66512)
Matthias Springer
llvmlistbot at llvm.org
Mon Sep 18 00:08:19 PDT 2023
https://github.com/matthias-springer updated https://github.com/llvm/llvm-project/pull/66512
>From 9d1ba8cce84366113a4a72eb5c36075b0b74996d Mon Sep 17 00:00:00 2001
From: Matthias Springer <me at m-sp.org>
Date: Mon, 18 Sep 2023 09:07:37 +0200
Subject: [PATCH] [mlir][SCF] scf.for: Consistent API around `initArgs`
* Always use the auto-generated `getInitArgs` function. Remove the hand-written `getIterOperands` duplicate.
* Remove `hasIterOperands` and `getNumIterOperands`. The names were inconsistent because the "arg" is called `initArgs` in TableGen. Use `getInitArgs().size()` instead.
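* Fix verification of ops with no results: the verifier previously returned early when the op had zero results, so a mismatch between the number of init args and the number of results went undetected. The count check now always runs.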
BEGIN_PUBLIC
No public commit message needed for presubmit.
END_PUBLIC
---
mlir/include/mlir/Dialect/SCF/IR/SCFOps.td | 11 -----
.../SCFToControlFlow/SCFToControlFlow.cpp | 3 +-
.../Conversion/VectorToGPU/VectorToGPU.cpp | 14 +++---
mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp | 2 +-
mlir/lib/Dialect/SCF/IR/SCF.cpp | 45 ++++++++-----------
.../SCF/Transforms/LoopCanonicalization.cpp | 21 +++++----
mlir/lib/Dialect/SCF/Utils/Utils.cpp | 6 +--
.../Vector/Transforms/VectorDistribute.cpp | 2 +-
mlir/lib/Target/Cpp/TranslateToCpp.cpp | 2 +-
mlir/test/Dialect/SCF/invalid.mlir | 13 ++++++
10 files changed, 55 insertions(+), 64 deletions(-)
diff --git a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
index 232e6b0bf4ed772..6d8aaf64e3263b9 100644
--- a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
+++ b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
@@ -250,9 +250,6 @@ def ForOp : SCF_Op<"for",
"expected an index less than the number of region iter args");
return getBody()->getArguments().drop_front(getNumInductionVars())[index];
}
- Operation::operand_range getIterOperands() {
- return getOperands().drop_front(getNumControlOperands());
- }
MutableArrayRef<OpOperand> getIterOpOperands() {
return
getOperation()->getOpOperands().drop_front(getNumControlOperands());
@@ -273,14 +270,6 @@ def ForOp : SCF_Op<"for",
}
/// Number of operands controlling the loop: lb, ub, step
unsigned getNumControlOperands() { return 3; }
- /// Does the operation hold operands for loop-carried values
- bool hasIterOperands() {
- return getOperation()->getNumOperands() > getNumControlOperands();
- }
- /// Get Number of loop-carried values
- unsigned getNumIterOperands() {
- return getOperation()->getNumOperands() - getNumControlOperands();
- }
/// Get the iter arg number for an operand. If it isnt an iter arg
/// operand return std::nullopt.
std::optional<unsigned> getIterArgNumberForOpOperand(OpOperand &opOperand) {
diff --git a/mlir/lib/Conversion/SCFToControlFlow/SCFToControlFlow.cpp b/mlir/lib/Conversion/SCFToControlFlow/SCFToControlFlow.cpp
index f5face5929916ae..c9b45fd4a7957b8 100644
--- a/mlir/lib/Conversion/SCFToControlFlow/SCFToControlFlow.cpp
+++ b/mlir/lib/Conversion/SCFToControlFlow/SCFToControlFlow.cpp
@@ -361,8 +361,7 @@ LogicalResult ForLowering::matchAndRewrite(ForOp forOp,
// of the loop operation.
SmallVector<Value, 8> destOperands;
destOperands.push_back(lowerBound);
- auto iterOperands = forOp.getIterOperands();
- destOperands.append(iterOperands.begin(), iterOperands.end());
+ llvm::append_range(destOperands, forOp.getInitArgs());
rewriter.create<cf::BranchOp>(loc, conditionBlock, destOperands);
// With the body block done, we can fill in the condition block.
diff --git a/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp b/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
index 3089e917d0eed9c..c8871c945cbe759 100644
--- a/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
+++ b/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
@@ -1106,14 +1106,14 @@ convertBroadcastOp(RewriterBase &rewriter, vector::BroadcastOp op,
// updated and needs to be updated separatly for the loop to be correct.
static scf::ForOp replaceForOpWithNewSignature(RewriterBase &rewriter,
scf::ForOp loop,
- ValueRange newIterOperands) {
+ ValueRange newInitArgs) {
OpBuilder::InsertionGuard g(rewriter);
rewriter.setInsertionPoint(loop);
// Create a new loop before the existing one, with the extra operands.
rewriter.setInsertionPoint(loop);
- auto operands = llvm::to_vector<4>(loop.getIterOperands());
- operands.append(newIterOperands.begin(), newIterOperands.end());
+ auto operands = llvm::to_vector<4>(loop.getInitArgs());
+ llvm::append_range(operands, newInitArgs);
scf::ForOp newLoop = rewriter.create<scf::ForOp>(
loop.getLoc(), loop.getLowerBound(), loop.getUpperBound(), loop.getStep(),
operands);
@@ -1122,7 +1122,7 @@ static scf::ForOp replaceForOpWithNewSignature(RewriterBase &rewriter,
newLoop.getLoopBody().getBlocks().splice(
newLoop.getLoopBody().getBlocks().begin(),
loop.getLoopBody().getBlocks());
- for (Value operand : newIterOperands)
+ for (Value operand : newInitArgs)
newLoop.getBody()->addArgument(operand.getType(), operand.getLoc());
for (auto it : llvm::zip(loop.getResults(), newLoop.getResults().take_front(
@@ -1144,14 +1144,14 @@ static LogicalResult convertForOp(RewriterBase &rewriter, scf::ForOp op,
SmallVector<Value> newOperands;
SmallVector<std::pair<size_t, size_t>> argMapping;
- for (const auto &operand : llvm::enumerate(op.getIterOperands())) {
+ for (const auto &operand : llvm::enumerate(op.getInitArgs())) {
auto it = valueMapping.find(operand.value());
if (it == valueMapping.end()) {
LLVM_DEBUG(DBGS() << "no value mapping for: " << operand.value() << "\n");
continue;
}
argMapping.push_back(std::make_pair(
- operand.index(), op.getNumIterOperands() + newOperands.size()));
+ operand.index(), op.getInitArgs().size() + newOperands.size()));
newOperands.push_back(it->second);
}
@@ -1183,7 +1183,7 @@ convertYieldOp(RewriterBase &rewriter, scf::YieldOp op,
continue;
// Replace the yield of old value with the for op argument to make it easier
// to remove the dead code.
- yieldOperands[operand.index()] = loop.getIterOperands()[operand.index()];
+ yieldOperands[operand.index()] = loop.getInitArgs()[operand.index()];
yieldOperands.push_back(it->second);
}
rewriter.create<scf::YieldOp>(op.getLoc(), yieldOperands);
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
index cf3fd4ba0a0b5dc..ae0461965c4785c 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
@@ -225,7 +225,7 @@ static void getProducerOfTensor(Value tensor, OpResult &opResult) {
}
if (auto blockArg = dyn_cast<BlockArgument>(tensor)) {
if (auto forOp = blockArg.getDefiningOp<scf::ForOp>()) {
- tensor = *(forOp.getIterOperands().begin() + blockArg.getArgNumber());
+ tensor = forOp.getInitArgs()[blockArg.getArgNumber()];
continue;
}
}
diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
index ce413b283730407..5565aefbad18db5 100644
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -343,15 +343,11 @@ LogicalResult ForOp::verify() {
if (matchPattern(getStep(), m_Constant(&step)) && step.getInt() <= 0)
return emitOpError("constant step operand must be positive");
- auto opNumResults = getNumResults();
- if (opNumResults == 0)
- return success();
- // If ForOp defines values, check that the number and types of
- // the defined values match ForOp initial iter operands and backedge
- // basic block arguments.
- if (getNumIterOperands() != opNumResults)
+ // Check that the number of init args and op results is the same.
+ if (getInitArgs().size() != getNumResults())
return emitOpError(
"mismatch in number of loop-carried values and defined values");
+
return success();
}
@@ -362,19 +358,15 @@ LogicalResult ForOp::verifyRegions() {
return emitOpError(
"expected induction variable to be same type as bounds and step");
- auto opNumResults = getNumResults();
- if (opNumResults == 0)
- return success();
-
- if (getNumRegionIterArgs() != opNumResults)
+ if (getNumRegionIterArgs() != getNumResults())
return emitOpError(
"mismatch in number of basic block args and defined values");
- auto iterOperands = getIterOperands();
+ auto initArgs = getInitArgs();
auto iterArgs = getRegionIterArgs();
auto opResults = getResults();
unsigned i = 0;
- for (auto e : llvm::zip(iterOperands, iterArgs, opResults)) {
+ for (auto e : llvm::zip(initArgs, iterArgs, opResults)) {
if (std::get<0>(e).getType() != std::get<2>(e).getType())
return emitOpError() << "types mismatch between " << i
<< "th iter operand and defined value";
@@ -419,7 +411,7 @@ LogicalResult ForOp::promoteIfSingleIteration(RewriterBase &rewriter) {
// iter_args.
SmallVector<Value> bbArgReplacements;
bbArgReplacements.push_back(getLowerBound());
- bbArgReplacements.append(getIterOperands().begin(), getIterOperands().end());
+ llvm::append_range(bbArgReplacements, getInitArgs());
// Move the loop body operations to the loop's containing block.
rewriter.inlineBlockBefore(getBody(), getOperation()->getBlock(),
@@ -456,16 +448,15 @@ void ForOp::print(OpAsmPrinter &p) {
p << " " << getInductionVar() << " = " << getLowerBound() << " to "
<< getUpperBound() << " step " << getStep();
- printInitializationList(p, getRegionIterArgs(), getIterOperands(),
- " iter_args");
- if (!getIterOperands().empty())
- p << " -> (" << getIterOperands().getTypes() << ')';
+ printInitializationList(p, getRegionIterArgs(), getInitArgs(), " iter_args");
+ if (!getInitArgs().empty())
+ p << " -> (" << getInitArgs().getTypes() << ')';
p << ' ';
if (Type t = getInductionVar().getType(); !t.isIndex())
p << " : " << t << ' ';
p.printRegion(getRegion(),
/*printEntryBlockArgs=*/false,
- /*printBlockTerminators=*/hasIterOperands());
+ /*printBlockTerminators=*/!getInitArgs().empty());
p.printOptionalAttrDict((*this)->getAttrs());
}
@@ -751,12 +742,12 @@ struct ForOpIterArgsFolder : public OpRewritePattern<scf::ForOp> {
keepMask.reserve(yieldOp.getNumOperands());
SmallVector<Value, 4> newBlockTransferArgs, newIterArgs, newYieldValues,
newResultValues;
- newBlockTransferArgs.reserve(1 + forOp.getNumIterOperands());
+ newBlockTransferArgs.reserve(1 + forOp.getInitArgs().size());
newBlockTransferArgs.push_back(Value()); // iv placeholder with null value
- newIterArgs.reserve(forOp.getNumIterOperands());
+ newIterArgs.reserve(forOp.getInitArgs().size());
newYieldValues.reserve(yieldOp.getNumOperands());
newResultValues.reserve(forOp.getNumResults());
- for (auto it : llvm::zip(forOp.getIterOperands(), // iter from outside
+ for (auto it : llvm::zip(forOp.getInitArgs(), // iter from outside
forOp.getRegionIterArgs(), // iter inside region
forOp.getResults(), // op results
yieldOp.getOperands() // iter yield
@@ -876,7 +867,7 @@ struct SimplifyTrivialLoops : public OpRewritePattern<ForOp> {
// If the upper bound is the same as the lower bound, the loop does not
// iterate, just remove it.
if (op.getLowerBound() == op.getUpperBound()) {
- rewriter.replaceOp(op, op.getIterOperands());
+ rewriter.replaceOp(op, op.getInitArgs());
return success();
}
@@ -887,7 +878,7 @@ struct SimplifyTrivialLoops : public OpRewritePattern<ForOp> {
// If the loop is known to have 0 iterations, remove it.
if (*diff <= 0) {
- rewriter.replaceOp(op, op.getIterOperands());
+ rewriter.replaceOp(op, op.getInitArgs());
return success();
}
@@ -900,9 +891,9 @@ struct SimplifyTrivialLoops : public OpRewritePattern<ForOp> {
llvm::APInt stepValue = *maybeStepValue;
if (stepValue.sge(*diff)) {
SmallVector<Value, 4> blockArgs;
- blockArgs.reserve(op.getNumIterOperands() + 1);
+ blockArgs.reserve(op.getInitArgs().size() + 1);
blockArgs.push_back(op.getLowerBound());
- llvm::append_range(blockArgs, op.getIterOperands());
+ llvm::append_range(blockArgs, op.getInitArgs());
replaceOpWithRegion(rewriter, op, op.getLoopBody(), blockArgs);
return success();
}
diff --git a/mlir/lib/Dialect/SCF/Transforms/LoopCanonicalization.cpp b/mlir/lib/Dialect/SCF/Transforms/LoopCanonicalization.cpp
index 1da10ddd6371f42..0cd19fbefa8ef98 100644
--- a/mlir/lib/Dialect/SCF/Transforms/LoopCanonicalization.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/LoopCanonicalization.cpp
@@ -48,16 +48,15 @@ static bool isShapePreserving(ForOp forOp, int64_t arg) {
return false;
using tensor::InsertSliceOp;
- value =
- llvm::TypeSwitch<Operation *, Value>(opResult.getOwner())
- .template Case<InsertSliceOp>(
- [&](InsertSliceOp op) { return op.getDest(); })
- .template Case<ForOp>([&](ForOp forOp) {
- return isShapePreserving(forOp, opResult.getResultNumber())
- ? forOp.getIterOperands()[opResult.getResultNumber()]
- : Value();
- })
- .Default([&](auto op) { return Value(); });
+ value = llvm::TypeSwitch<Operation *, Value>(opResult.getOwner())
+ .template Case<InsertSliceOp>(
+ [&](InsertSliceOp op) { return op.getDest(); })
+ .template Case<ForOp>([&](ForOp forOp) {
+ return isShapePreserving(forOp, opResult.getResultNumber())
+ ? forOp.getInitArgs()[opResult.getResultNumber()]
+ : Value();
+ })
+ .Default([&](auto op) { return Value(); });
}
return false;
}
@@ -144,7 +143,7 @@ struct DimOfLoopResultFolder : public OpRewritePattern<OpTy> {
if (!isShapePreserving(forOp, resultNumber))
return failure();
rewriter.updateRootInPlace(dimOp, [&]() {
- dimOp.getSourceMutable().assign(forOp.getIterOperands()[resultNumber]);
+ dimOp.getSourceMutable().assign(forOp.getInitArgs()[resultNumber]);
});
return success();
}
diff --git a/mlir/lib/Dialect/SCF/Utils/Utils.cpp b/mlir/lib/Dialect/SCF/Utils/Utils.cpp
index 9ac751f1915ab14..222a9aa395c4f09 100644
--- a/mlir/lib/Dialect/SCF/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/SCF/Utils/Utils.cpp
@@ -46,8 +46,8 @@ mlir::replaceLoopWithNewYields(OpBuilder &builder, scf::ForOp loop,
// Create a new loop before the existing one, with the extra operands.
OpBuilder::InsertionGuard g(builder);
builder.setInsertionPoint(loop);
- auto operands = llvm::to_vector(loop.getIterOperands());
- operands.append(newIterOperands.begin(), newIterOperands.end());
+ auto operands = llvm::to_vector(loop.getInitArgs());
+ llvm::append_range(operands, newIterOperands);
scf::ForOp newLoop = builder.create<scf::ForOp>(
loop.getLoc(), loop.getLowerBound(), loop.getUpperBound(), loop.getStep(),
operands, [](OpBuilder &, Location, Value, ValueRange) {});
@@ -515,7 +515,7 @@ LogicalResult mlir::loopUnrollByFactor(
std::get<0>(e).replaceAllUsesWith(std::get<1>(e));
}
epilogueForOp->setOperands(epilogueForOp.getNumControlOperands(),
- epilogueForOp.getNumIterOperands(), results);
+ epilogueForOp.getInitArgs().size(), results);
(void)epilogueForOp.promoteIfSingleIteration(rewriter);
}
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
index 05b5ff09321489f..2a50947e976dffb 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
@@ -1491,7 +1491,7 @@ struct WarpOpScfForOp : public OpRewritePattern<WarpExecuteOnLane0Op> {
auto forResult = cast<OpResult>(yieldOperand.get());
newOperands.push_back(
newWarpOp.getResult(yieldOperand.getOperandNumber()));
- yieldOperand.set(forOp.getIterOperands()[forResult.getResultNumber()]);
+ yieldOperand.set(forOp.getInitArgs()[forResult.getResultNumber()]);
resultIdx.push_back(yieldOperand.getOperandNumber());
}
diff --git a/mlir/lib/Target/Cpp/TranslateToCpp.cpp b/mlir/lib/Target/Cpp/TranslateToCpp.cpp
index 91a4db9cb8be8ec..832dd8f2013fa4d 100644
--- a/mlir/lib/Target/Cpp/TranslateToCpp.cpp
+++ b/mlir/lib/Target/Cpp/TranslateToCpp.cpp
@@ -493,7 +493,7 @@ static LogicalResult printOperation(CppEmitter &emitter, scf::ForOp forOp) {
raw_indented_ostream &os = emitter.ostream();
- OperandRange operands = forOp.getIterOperands();
+ OperandRange operands = forOp.getInitArgs();
Block::BlockArgListType iterArgs = forOp.getRegionIterArgs();
Operation::result_range results = forOp.getResults();
diff --git a/mlir/test/Dialect/SCF/invalid.mlir b/mlir/test/Dialect/SCF/invalid.mlir
index 0cf587af42637c9..f6044ad10829227 100644
--- a/mlir/test/Dialect/SCF/invalid.mlir
+++ b/mlir/test/Dialect/SCF/invalid.mlir
@@ -83,6 +83,19 @@ func.func @loop_for_single_index_argument(%arg0: index) {
// -----
+func.func @not_enough_loop_results(%arg0: index, %init: f32) {
+ // expected-error @below{{mismatch in number of loop-carried values and defined values}}
+ "scf.for"(%arg0, %arg0, %arg0, %init) (
+ {
+ ^bb0(%i0 : index, %iter: f32):
+ scf.yield %iter : f32
+ }
+ ) : (index, index, index, f32) -> ()
+ return
+}
+
+// -----
+
func.func @loop_if_not_i1(%arg0: index) {
// expected-error at +1 {{operand #0 must be 1-bit signless integer}}
"scf.if"(%arg0) ({}, {}) : (index) -> ()
More information about the Mlir-commits
mailing list