[flang-commits] [flang] [mlir][Interfaces] `LoopLikeOpInterface`: Add helper to get yielded values (PR #67305)
Matthias Springer via flang-commits
flang-commits at lists.llvm.org
Mon Sep 25 03:30:57 PDT 2023
https://github.com/matthias-springer updated https://github.com/llvm/llvm-project/pull/67305
>From 396a6f07aa9f94474f304c57f94ca98258abe83a Mon Sep 17 00:00:00 2001
From: Matthias Springer <me at m-sp.org>
Date: Mon, 25 Sep 2023 12:28:53 +0200
Subject: [PATCH 1/2] [mlir][Interfaces] `LoopLikeOpInterface`: Add
`replaceWithAdditionalYields`
`affine::replaceForOpWithNewYields` and `replaceLoopWithNewYields` (for "scf.for") are replaced by the new `LoopLikeOpInterface` method `replaceWithAdditionalYields`, so additional loop-carried variables can now be added to "scf.for"/"affine.for" uniformly. (No more `TypeSwitch` needed.)
Note: `scf.while` and other loops with loop-carried variables could also implement `replaceWithAdditionalYields`; to keep this commit small, that is left for a follow-up.
BEGIN_PUBLIC
No public commit message needed for presubmit.
END_PUBLIC
---
.../mlir/Dialect/Affine/IR/AffineOps.h | 14 ---
.../mlir/Dialect/Affine/IR/AffineOps.td | 2 +-
mlir/include/mlir/Dialect/SCF/IR/SCFOps.td | 3 +-
mlir/include/mlir/Dialect/SCF/Utils/Utils.h | 42 +-------
.../mlir/Interfaces/LoopLikeInterface.h | 7 ++
.../mlir/Interfaces/LoopLikeInterface.td | 43 +++++++++
mlir/lib/Dialect/Affine/IR/AffineOps.cpp | 96 ++++++++++---------
.../Dialect/Affine/Utils/LoopFusionUtils.cpp | 20 ++--
mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp | 22 +++--
.../Linalg/Transforms/HoistPadding.cpp | 7 +-
.../Dialect/Linalg/Transforms/Hoisting.cpp | 53 +++-------
.../Linalg/Transforms/SubsetHoisting.cpp | 20 ++--
mlir/lib/Dialect/SCF/IR/SCF.cpp | 53 ++++++++++
.../SCF/Transforms/TileUsingInterface.cpp | 3 +-
mlir/lib/Dialect/SCF/Utils/Utils.cpp | 91 +++---------------
.../scf-replace-with-new-yields.mlir | 3 -
mlir/test/lib/Dialect/SCF/TestSCFUtils.cpp | 18 ++--
17 files changed, 241 insertions(+), 256 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.h b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.h
index 704c2704536d20a..56b4a609e62c001 100644
--- a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.h
+++ b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.h
@@ -490,20 +490,6 @@ void buildAffineLoopNest(OpBuilder &builder, Location loc, ValueRange lbs,
function_ref<void(OpBuilder &, Location, ValueRange)>
bodyBuilderFn = nullptr);
-/// Replace `loop` with a new loop where `newIterOperands` are appended with
-/// new initialization values and `newYieldedValues` are added as new yielded
-/// values. The returned ForOp has `newYieldedValues.size()` new result values.
-/// Additionally, if `replaceLoopResults` is true, all uses of
-/// `loop.getResults()` are replaced with the first `loop.getNumResults()`
-/// return values of the original loop respectively. The original loop is
-/// deleted and the new loop returned.
-/// Prerequisite: `newIterOperands.size() == newYieldedValues.size()`.
-AffineForOp replaceForOpWithNewYields(OpBuilder &b, AffineForOp loop,
- ValueRange newIterOperands,
- ValueRange newYieldedValues,
- ValueRange newIterArgs,
- bool replaceLoopResults = true);
-
/// AffineBound represents a lower or upper bound in the for operation.
/// This class does not own the underlying operands. Instead, it refers
/// to the operands stored in the AffineForOp. Its life span should not exceed
diff --git a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
index 5a1baaf4e1611c8..d8ef0506d0822d7 100644
--- a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
+++ b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
@@ -120,7 +120,7 @@ def AffineForOp : Affine_Op<"for",
[AutomaticAllocationScope, ImplicitAffineTerminator, ConditionallySpeculatable,
RecursiveMemoryEffects, DeclareOpInterfaceMethods<LoopLikeOpInterface,
["getSingleInductionVar", "getSingleLowerBound", "getSingleStep",
- "getSingleUpperBound"]>,
+ "getSingleUpperBound", "replaceWithAdditionalYields"]>,
DeclareOpInterfaceMethods<RegionBranchOpInterface,
["getEntrySuccessorOperands"]>]> {
let summary = "for operation";
diff --git a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
index 0c93989ca99a4eb..e62015b84888a18 100644
--- a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
+++ b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
@@ -121,7 +121,8 @@ def ExecuteRegionOp : SCF_Op<"execute_region", [
def ForOp : SCF_Op<"for",
[AutomaticAllocationScope, DeclareOpInterfaceMethods<LoopLikeOpInterface,
["getInits", "getSingleInductionVar", "getSingleLowerBound",
- "getSingleStep", "getSingleUpperBound", "promoteIfSingleIteration"]>,
+ "getSingleStep", "getSingleUpperBound", "promoteIfSingleIteration",
+ "replaceWithAdditionalYields"]>,
AllTypesMatch<["lowerBound", "upperBound", "step"]>,
ConditionallySpeculatable,
DeclareOpInterfaceMethods<RegionBranchOpInterface,
diff --git a/mlir/include/mlir/Dialect/SCF/Utils/Utils.h b/mlir/include/mlir/Dialect/SCF/Utils/Utils.h
index bde30c9c3528dbc..9bdd6eb833876f0 100644
--- a/mlir/include/mlir/Dialect/SCF/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/SCF/Utils/Utils.h
@@ -34,39 +34,6 @@ class CallOp;
class FuncOp;
} // namespace func
-/// Replace the `loop` with `newIterOperands` added as new initialization
-/// values. `newYieldValuesFn` is a callback that can be used to specify
-/// the additional values to be yielded by the loop. The number of
-/// values returned by the callback should match the number of new
-/// initialization values. This function
-/// - Moves (i.e. doesnt clone) operations from the `loop` to the newly created
-/// loop
-/// - Replaces the uses of `loop` with the new loop.
-/// - `loop` isnt erased, but is left in a "no-op" state where the body of the
-/// loop just yields the basic block arguments that correspond to the
-/// initialization values of a loop. The loop is dead after this method.
-/// - If `replaceIterOperandsUsesInLoop` is true, all uses of the
-/// `newIterOperands` within the generated new loop are replaced
-/// with the corresponding `BlockArgument` in the loop body.
-using NewYieldValueFn = std::function<SmallVector<Value>(
- OpBuilder &b, Location loc, ArrayRef<BlockArgument> newBBArgs)>;
-scf::ForOp replaceLoopWithNewYields(OpBuilder &builder, scf::ForOp loop,
- ValueRange newIterOperands,
- const NewYieldValueFn &newYieldValuesFn,
- bool replaceIterOperandsUsesInLoop = true);
-// Simpler API if the new yields are just a list of values that can be
-// determined ahead of time.
-inline scf::ForOp
-replaceLoopWithNewYields(OpBuilder &builder, scf::ForOp loop,
- ValueRange newIterOperands, ValueRange newYields,
- bool replaceIterOperandsUsesInLoop = true) {
- auto fn = [&](OpBuilder &b, Location loc, ArrayRef<BlockArgument> newBBArgs) {
- return SmallVector<Value>(newYields.begin(), newYields.end());
- };
- return replaceLoopWithNewYields(builder, loop, newIterOperands, fn,
- replaceIterOperandsUsesInLoop);
-}
-
/// Update a perfectly nested loop nest to yield new values from the innermost
/// loop and propagating it up through the loop nest. This function
/// - Expects `loopNest` to be a perfectly nested loop with outer most loop
@@ -82,11 +49,10 @@ replaceLoopWithNewYields(OpBuilder &builder, scf::ForOp loop,
/// - If `replaceIterOperandsUsesInLoop` is true, all uses of the
/// `newIterOperands` within the generated new loop are replaced with the
/// corresponding `BlockArgument` in the loop body.
-SmallVector<scf::ForOp>
-replaceLoopNestWithNewYields(OpBuilder &builder, ArrayRef<scf::ForOp> loopNest,
- ValueRange newIterOperands,
- const NewYieldValueFn &newYieldValueFn,
- bool replaceIterOperandsUsesInLoop = true);
+SmallVector<scf::ForOp> replaceLoopNestWithNewYields(
+ RewriterBase &rewriter, MutableArrayRef<scf::ForOp> loopNest,
+ ValueRange newIterOperands, const NewYieldValuesFn &newYieldValuesFn,
+ bool replaceIterOperandsUsesInLoop = true);
/// Outline a region with a single block into a new FuncOp.
/// Assumes the FuncOp result types is the type of the yielded operands of the
diff --git a/mlir/include/mlir/Interfaces/LoopLikeInterface.h b/mlir/include/mlir/Interfaces/LoopLikeInterface.h
index 9d81a61fac88566..0eebb984e5897ae 100644
--- a/mlir/include/mlir/Interfaces/LoopLikeInterface.h
+++ b/mlir/include/mlir/Interfaces/LoopLikeInterface.h
@@ -17,6 +17,13 @@
namespace mlir {
class RewriterBase;
+
+/// A function that returns the additional yielded values during
+/// `replaceWithAdditionalYields`. `newBbArgs` are the newly added region
+/// iter_args. This function should return as many values as there are block
+/// arguments in `newBbArgs`.
+using NewYieldValuesFn = std::function<SmallVector<Value>(
+ OpBuilder &b, Location loc, ArrayRef<BlockArgument> newBbArgs)>;
} // namespace mlir
/// Include the generated interface declarations.
diff --git a/mlir/include/mlir/Interfaces/LoopLikeInterface.td b/mlir/include/mlir/Interfaces/LoopLikeInterface.td
index cb6b2f4ed4ae8b5..ded0a29292ff6f0 100644
--- a/mlir/include/mlir/Interfaces/LoopLikeInterface.td
+++ b/mlir/include/mlir/Interfaces/LoopLikeInterface.td
@@ -141,6 +141,31 @@ def LoopLikeOpInterface : OpInterface<"LoopLikeOpInterface"> {
return ::mlir::Block::BlockArgListType();
}]
>,
+ InterfaceMethod<[{
+ Append the specified additional "init" operands: replace this loop with
+ a new loop that has the additional init operands. The loop body of
+ this loop is moved over to the new loop.
+
+ `newInitOperands` specifies the additional "init" operands.
+ `newYieldValuesFn` is a function that returns the additional yielded values
+ (which can be computed based on the newly added region iter_args). If
+ `replaceInitOperandUsesInLoop` is set, all uses of the additional init
+ operands inside of this loop are replaced with the corresponding, newly
+ added region iter_args.
+
+ Note: Loops that do not support init/iter_args should return "failure".
+ }],
+ /*retTy=*/"::mlir::FailureOr<::mlir::LoopLikeOpInterface>",
+ /*methodName=*/"replaceWithAdditionalYields",
+ /*args=*/(ins "::mlir::RewriterBase &":$rewriter,
+ "::mlir::ValueRange":$newInitOperands,
+ "bool":$replaceInitOperandUsesInLoop,
+ "const ::mlir::NewYieldValuesFn &":$newYieldValuesFn),
+ /*methodBody=*/"",
+ /*defaultImplementation=*/[{
+ return ::mlir::failure();
+ }]
+ >,
];
let extraClassDeclaration = [{
@@ -149,6 +174,24 @@ def LoopLikeOpInterface : OpInterface<"LoopLikeOpInterface"> {
/// because the control flow graph is cyclic
static bool blockIsInLoop(Block *block);
}];
+
+ let extraSharedClassDeclaration = [{
+ /// Append the specified additional "init" operands: replace this loop with
+ /// a new loop that has the additional init operands. The loop body of this
+ /// loop is moved over to the new loop.
+ ///
+ /// The newly added region iter_args are yielded from the loop.
+ ::mlir::FailureOr<::mlir::LoopLikeOpInterface>
+ replaceWithAdditionalIterOperands(::mlir::RewriterBase &rewriter,
+ ::mlir::ValueRange newInitOperands,
+ bool replaceInitOperandUsesInLoop) {
+ return $_op.replaceWithAdditionalYields(
+ rewriter, newInitOperands, replaceInitOperandUsesInLoop,
+ [](OpBuilder &b, Location loc, ArrayRef<BlockArgument> newBBArgs) {
+ return SmallVector<Value>(newBBArgs);
+ });
+ }
+ }];
}
#endif // MLIR_INTERFACES_LOOPLIKEINTERFACE
diff --git a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
index 9ecc568883a3b0b..6c060c90e24af82 100644
--- a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
+++ b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
@@ -2575,6 +2575,58 @@ std::optional<OpFoldResult> AffineForOp::getSingleUpperBound() {
return OpFoldResult(b.getI64IntegerAttr(getConstantUpperBound()));
}
+FailureOr<LoopLikeOpInterface> AffineForOp::replaceWithAdditionalYields(
+ RewriterBase &rewriter, ValueRange newInitOperands,
+ bool replaceInitOperandUsesInLoop,
+ const NewYieldValuesFn &newYieldValuesFn) {
+ // Create a new loop before the existing one, with the extra operands.
+ OpBuilder::InsertionGuard g(rewriter);
+ rewriter.setInsertionPoint(getOperation());
+ auto inits = llvm::to_vector(getInits());
+ inits.append(newInitOperands.begin(), newInitOperands.end());
+ AffineForOp newLoop = rewriter.create<AffineForOp>(
+ getLoc(), getLowerBoundOperands(), getLowerBoundMap(),
+ getUpperBoundOperands(), getUpperBoundMap(), getStep(), inits);
+
+ // Generate the new yield values and append them to the affine.yield operation.
+ auto yieldOp = cast<AffineYieldOp>(getBody()->getTerminator());
+ ArrayRef<BlockArgument> newIterArgs =
+ newLoop.getBody()->getArguments().take_back(newInitOperands.size());
+ {
+ OpBuilder::InsertionGuard g(rewriter);
+ rewriter.setInsertionPoint(yieldOp);
+ SmallVector<Value> newYieldedValues =
+ newYieldValuesFn(rewriter, getLoc(), newIterArgs);
+ assert(newInitOperands.size() == newYieldedValues.size() &&
+ "expected as many new yield values as new iter operands");
+ rewriter.updateRootInPlace(yieldOp, [&]() {
+ yieldOp.getOperandsMutable().append(newYieldedValues);
+ });
+ }
+
+ // Move the loop body to the new op.
+ rewriter.mergeBlocks(getBody(), newLoop.getBody(),
+ newLoop.getBody()->getArguments().take_front(
+ getBody()->getNumArguments()));
+
+ if (replaceInitOperandUsesInLoop) {
+ // Replace all uses of `newInitOperands` with the corresponding basic block
+ // arguments.
+ for (auto it : llvm::zip(newInitOperands, newIterArgs)) {
+ rewriter.replaceUsesWithIf(std::get<0>(it), std::get<1>(it),
+ [&](OpOperand &use) {
+ Operation *user = use.getOwner();
+ return newLoop->isProperAncestor(user);
+ });
+ }
+ }
+
+ // Replace the old loop.
+ rewriter.replaceOp(getOperation(),
+ newLoop->getResults().take_front(getNumResults()));
+ return cast<LoopLikeOpInterface>(newLoop.getOperation());
+}
+
Speculation::Speculatability AffineForOp::getSpeculatability() {
// `affine.for (I = Start; I < End; I += 1)` terminates for all values of
// Start and End.
@@ -2725,50 +2777,6 @@ void mlir::affine::buildAffineLoopNest(
buildAffineLoopFromValues);
}
-AffineForOp mlir::affine::replaceForOpWithNewYields(OpBuilder &b,
- AffineForOp loop,
- ValueRange newIterOperands,
- ValueRange newYieldedValues,
- ValueRange newIterArgs,
- bool replaceLoopResults) {
- assert(newIterOperands.size() == newYieldedValues.size() &&
- "newIterOperands must be of the same size as newYieldedValues");
- // Create a new loop before the existing one, with the extra operands.
- OpBuilder::InsertionGuard g(b);
- b.setInsertionPoint(loop);
- auto operands = llvm::to_vector<4>(loop.getInits());
- operands.append(newIterOperands.begin(), newIterOperands.end());
- SmallVector<Value, 4> lbOperands(loop.getLowerBoundOperands());
- SmallVector<Value, 4> ubOperands(loop.getUpperBoundOperands());
- SmallVector<Value, 4> steps(loop.getStep());
- auto lbMap = loop.getLowerBoundMap();
- auto ubMap = loop.getUpperBoundMap();
- AffineForOp newLoop =
- b.create<AffineForOp>(loop.getLoc(), lbOperands, lbMap, ubOperands, ubMap,
- loop.getStep(), operands);
- // Take the body of the original parent loop.
- newLoop.getRegion().takeBody(loop.getRegion());
- for (Value val : newIterArgs)
- newLoop.getRegion().addArgument(val.getType(), val.getLoc());
-
- // Update yield operation with new values to be added.
- if (!newYieldedValues.empty()) {
- auto yield = cast<AffineYieldOp>(newLoop.getBody()->getTerminator());
- b.setInsertionPoint(yield);
- auto yieldOperands = llvm::to_vector<4>(yield.getOperands());
- yieldOperands.append(newYieldedValues.begin(), newYieldedValues.end());
- b.create<AffineYieldOp>(yield.getLoc(), yieldOperands);
- yield.erase();
- }
- if (replaceLoopResults) {
- for (auto it : llvm::zip(loop.getResults(), newLoop.getResults().take_front(
- loop.getNumResults()))) {
- std::get<0>(it).replaceAllUsesWith(std::get<1>(it));
- }
- }
- return newLoop;
-}
-
//===----------------------------------------------------------------------===//
// AffineIfOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Affine/Utils/LoopFusionUtils.cpp b/mlir/lib/Dialect/Affine/Utils/LoopFusionUtils.cpp
index 3ecb8664e3fd765..5053b08ee0834cd 100644
--- a/mlir/lib/Dialect/Affine/Utils/LoopFusionUtils.cpp
+++ b/mlir/lib/Dialect/Affine/Utils/LoopFusionUtils.cpp
@@ -19,6 +19,7 @@
#include "mlir/Dialect/Affine/LoopUtils.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/Operation.h"
+#include "mlir/IR/PatternMatch.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"
#include <optional>
@@ -361,16 +362,22 @@ static LogicalResult promoteSingleIterReductionLoop(AffineForOp forOp,
std::optional<uint64_t> tripCount = getConstantTripCount(forOp);
if (!tripCount || *tripCount != 1)
return failure();
- auto iterOperands = forOp.getInits();
auto *parentOp = forOp->getParentOp();
if (!isa<AffineForOp>(parentOp))
return failure();
- auto newOperands = forOp.getBody()->getTerminator()->getOperands();
- OpBuilder b(parentOp);
+ SmallVector<Value> newOperands;
+ llvm::append_range(newOperands,
+ forOp.getBody()->getTerminator()->getOperands());
+ IRRewriter rewriter(parentOp->getContext());
+ int64_t parentOpNumResults = parentOp->getNumResults();
// Replace the parent loop and add iteroperands and results from the `forOp`.
AffineForOp parentForOp = forOp->getParentOfType<AffineForOp>();
- AffineForOp newLoop = replaceForOpWithNewYields(
- b, parentForOp, iterOperands, newOperands, forOp.getRegionIterArgs());
+ AffineForOp newLoop =
+ cast<AffineForOp>(*parentForOp.replaceWithAdditionalYields(
+ rewriter, forOp.getInits(), /*replaceInitOperandUsesInLoop=*/false,
+ [&](OpBuilder &b, Location loc, ArrayRef<BlockArgument> newBbArgs) {
+ return newOperands;
+ }));
// For sibling-fusion users, collect operations that use the results of the
// `forOp` outside the new parent loop that has absorbed all its iter args
@@ -387,7 +394,7 @@ static LogicalResult promoteSingleIterReductionLoop(AffineForOp forOp,
// Update the results of the `forOp` in the new loop.
for (unsigned i = 0, e = forOp.getNumResults(); i != e; ++i) {
forOp.getResult(i).replaceAllUsesWith(
- newLoop.getResult(i + parentOp->getNumResults()));
+ newLoop.getResult(i + parentOpNumResults));
}
// For sibling-fusion users, move operations that use the results of the
// `forOp` outside the new parent loop
@@ -412,7 +419,6 @@ static LogicalResult promoteSingleIterReductionLoop(AffineForOp forOp,
parentBlock->getOperations().splice(Block::iterator(forOp),
forOp.getBody()->getOperations());
forOp.erase();
- parentForOp.erase();
return success();
}
diff --git a/mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp b/mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp
index e6c4b2f8447470c..9d8ed9b4ac93387 100644
--- a/mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp
+++ b/mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp
@@ -1197,9 +1197,9 @@ LogicalResult mlir::affine::loopUnrollJamByFactor(AffineForOp forOp,
// `unrollJamFactor` copies of its iterOperands, iter_args and yield
// operands.
SmallVector<AffineForOp, 4> newLoopsWithIterArgs;
- OpBuilder builder(forOp.getContext());
+ IRRewriter rewriter(forOp.getContext());
for (AffineForOp oldForOp : loopsWithIterArgs) {
- SmallVector<Value, 4> dupIterOperands, dupIterArgs, dupYieldOperands;
+ SmallVector<Value> dupIterOperands, dupYieldOperands;
ValueRange oldIterOperands = oldForOp.getInits();
ValueRange oldIterArgs = oldForOp.getRegionIterArgs();
ValueRange oldYieldOperands =
@@ -1208,19 +1208,21 @@ LogicalResult mlir::affine::loopUnrollJamByFactor(AffineForOp forOp,
// fix iterOperands and yield operands after cloning of sub-blocks.
for (unsigned i = unrollJamFactor - 1; i >= 1; --i) {
dupIterOperands.append(oldIterOperands.begin(), oldIterOperands.end());
- dupIterArgs.append(oldIterArgs.begin(), oldIterArgs.end());
dupYieldOperands.append(oldYieldOperands.begin(), oldYieldOperands.end());
}
// Create a new loop with additional iterOperands, iter_args and yield
// operands. This new loop will take the loop body of the original loop.
- AffineForOp newForOp = affine::replaceForOpWithNewYields(
- builder, oldForOp, dupIterOperands, dupYieldOperands, dupIterArgs);
+ bool forOpReplaced = oldForOp == forOp;
+ AffineForOp newForOp =
+ cast<AffineForOp>(*oldForOp.replaceWithAdditionalYields(
+ rewriter, dupIterOperands, /*replaceInitOperandUsesInLoop=*/false,
+ [&](OpBuilder &b, Location loc, ArrayRef<BlockArgument> newBbArgs) {
+ return dupYieldOperands;
+ }));
newLoopsWithIterArgs.push_back(newForOp);
// `forOp` has been replaced with a new loop.
- if (oldForOp == forOp)
+ if (forOpReplaced)
forOp = newForOp;
- assert(oldForOp.use_empty() && "old for op should not have any user");
- oldForOp.erase();
// Update `operandMaps` for `newForOp` iterArgs and results.
ValueRange newIterArgs = newForOp.getRegionIterArgs();
unsigned oldNumIterArgs = oldIterArgs.size();
@@ -1294,7 +1296,7 @@ LogicalResult mlir::affine::loopUnrollJamByFactor(AffineForOp forOp,
// into one value. For example, for %0:2 = affine.for ... and addf, we add
// %1 = arith.addf %0#0, %0#1, and replace the following uses of %0#0 with
// %1.
- builder.setInsertionPointAfter(forOp);
+ rewriter.setInsertionPointAfter(forOp);
auto loc = forOp.getLoc();
unsigned oldNumResults = forOp.getNumResults() / unrollJamFactor;
for (LoopReduction &reduction : reductions) {
@@ -1305,7 +1307,7 @@ LogicalResult mlir::affine::loopUnrollJamByFactor(AffineForOp forOp,
for (unsigned i = unrollJamFactor - 1; i >= 1; --i) {
rhs = forOp.getResult(i * oldNumResults + pos);
// Create ops based on reduction type.
- lhs = arith::getReductionOp(reduction.kind, builder, loc, lhs, rhs);
+ lhs = arith::getReductionOp(reduction.kind, rewriter, loc, lhs, rhs);
if (!lhs)
return failure();
Operation *op = lhs.getDefiningOp();
diff --git a/mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp b/mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp
index a9debb7bbc489a4..72bd2b409f5d52b 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp
@@ -842,8 +842,11 @@ padThroughLoopIterArg(RewriterBase &rewriter, Value paddedValueBeforeHoisting,
outerSliceOp.getMixedStrides());
rewriter.replaceAllUsesWith(forOp.getResult(iterArgNumber), extracted);
}
- scf::ForOp newForOp =
- replaceLoopWithNewYields(rewriter, forOp, initArgs, yieldOperands);
+ scf::ForOp newForOp = cast<scf::ForOp>(*forOp.replaceWithAdditionalYields(
+ rewriter, initArgs, /*replaceInitOperandUsesInLoop=*/true,
+ [&](OpBuilder &b, Location loc, ArrayRef<BlockArgument> newBBArgs) {
+ return yieldOperands;
+ }));
LLVM_DEBUG(DBGS() << "newForOp results: " << newForOp.getNumResults()
<< "\n");
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp b/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp
index 7c6639304d97c58..1c68ca49725effb 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp
@@ -191,47 +191,24 @@ void mlir::linalg::hoistRedundantVectorTransfers(func::FuncOp func) {
transferWrite->moveAfter(loop);
// Rewrite `loop` with new yields by cloning and erase the original loop.
- OpBuilder b(transferRead);
- NewYieldValueFn yieldFn = [&](OpBuilder &b, Location loc,
- ArrayRef<BlockArgument> newBBArgs) {
+ IRRewriter rewriter(transferRead.getContext());
+ NewYieldValuesFn yieldFn = [&](OpBuilder &b, Location loc,
+ ArrayRef<BlockArgument> newBBArgs) {
return SmallVector<Value>{transferWrite.getVector()};
};
- // Transfer write has been hoisted, need to update the written vector by
- // the value yielded by the newForOp.
- return TypeSwitch<Operation *, WalkResult>(loop)
- .Case<scf::ForOp>([&](scf::ForOp scfForOp) {
- auto newForOp = replaceLoopWithNewYields(
- b, scfForOp, transferRead.getVector(), yieldFn);
- transferWrite.getVectorMutable().assign(
- newForOp.getResults().back());
- changed = true;
- loop.erase();
- // Need to interrupt and restart because erasing the loop messes up
- // the walk.
- return WalkResult::interrupt();
- })
- .Case<affine::AffineForOp>([&](affine::AffineForOp affineForOp) {
- auto newForOp = replaceForOpWithNewYields(
- b, affineForOp, transferRead.getVector(),
- SmallVector<Value>{transferWrite.getVector()},
- transferWrite.getVector());
- // Replace all uses of the `transferRead` with the corresponding
- // basic block argument.
- transferRead.getVector().replaceUsesWithIf(
- newForOp.getBody()->getArguments().back(), [&](OpOperand &use) {
- Operation *user = use.getOwner();
- return newForOp->isProperAncestor(user);
- });
- transferWrite.getVectorMutable().assign(
- newForOp.getResults().back());
- changed = true;
- loop.erase();
- // Need to interrupt and restart because erasing the loop messes up
- // the walk.
- return WalkResult::interrupt();
- })
- .Default([](Operation *) { return WalkResult::interrupt(); });
+ auto maybeNewLoop = loop.replaceWithAdditionalYields(
+ rewriter, transferRead.getVector(),
+ /*replaceInitOperandUsesInLoop=*/true, yieldFn);
+ if (failed(maybeNewLoop))
+ return WalkResult::interrupt();
+
+ transferWrite.getVectorMutable().assign(
+ maybeNewLoop->getOperation()->getResults().back());
+ changed = true;
+ // Need to interrupt and restart because erasing the loop messes up
+ // the walk.
+ return WalkResult::interrupt();
});
}
}
diff --git a/mlir/lib/Dialect/Linalg/Transforms/SubsetHoisting.cpp b/mlir/lib/Dialect/Linalg/Transforms/SubsetHoisting.cpp
index 7ab4ea41a2cd89d..91e0d139ec5c2f0 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/SubsetHoisting.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/SubsetHoisting.cpp
@@ -363,13 +363,13 @@ static scf::ForOp hoistTransferReadWrite(
// 2. Rewrite `loop` with an additional yield. This is the quantity that is
// computed iteratively but whose storage has become loop-invariant.
- NewYieldValueFn yieldFn = [&](OpBuilder &b, Location loc,
- ArrayRef<BlockArgument> newBBArgs) {
+ NewYieldValuesFn yieldFn = [&](OpBuilder &b, Location loc,
+ ArrayRef<BlockArgument> newBBArgs) {
return SmallVector<Value>{transferWriteOp.getVector()};
};
- auto newForOp = replaceLoopWithNewYields(
- rewriter, forOp, {transferReadOp.getVector()}, yieldFn);
- rewriter.eraseOp(forOp);
+ auto newForOp = cast<scf::ForOp>(*forOp.replaceWithAdditionalYields(
+ rewriter, {transferReadOp.getVector()},
+ /*replaceInitOperandUsesInLoop=*/true, yieldFn));
// 3. Update the yield. Invariant: initArgNumber is the destination tensor.
auto yieldOp =
@@ -425,13 +425,13 @@ static scf::ForOp hoistExtractInsertSlice(RewriterBase &rewriter,
// 2. Rewrite `loop` with an additional yield. This is the quantity that is
// computed iteratively but whose storage has become loop-invariant.
- NewYieldValueFn yieldFn = [&](OpBuilder &b, Location loc,
- ArrayRef<BlockArgument> newBBArgs) {
+ NewYieldValuesFn yieldFn = [&](OpBuilder &b, Location loc,
+ ArrayRef<BlockArgument> newBBArgs) {
return SmallVector<Value>{insertSliceOp.getSource()};
};
- auto newForOp = replaceLoopWithNewYields(rewriter, forOp,
- extractSliceOp.getResult(), yieldFn);
- rewriter.eraseOp(forOp);
+ auto newForOp = cast<scf::ForOp>(*forOp.replaceWithAdditionalYields(
+ rewriter, extractSliceOp.getResult(),
+ /*replaceInitOperandUsesInLoop=*/true, yieldFn));
// 3. Update the yield. Invariant: initArgNumber is the destination tensor.
auto yieldOp =
diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
index 8788597a1cefcfa..8d8481421e18d57 100644
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -529,6 +529,59 @@ SmallVector<Region *> ForOp::getLoopRegions() { return {&getRegion()}; }
OperandRange ForOp::getInits() { return getInitArgs(); }
+FailureOr<LoopLikeOpInterface>
+ForOp::replaceWithAdditionalYields(RewriterBase &rewriter,
+ ValueRange newInitOperands,
+ bool replaceInitOperandUsesInLoop,
+ const NewYieldValuesFn &newYieldValuesFn) {
+ // Create a new loop before the existing one, with the extra operands.
+ OpBuilder::InsertionGuard g(rewriter);
+ rewriter.setInsertionPoint(getOperation());
+ auto inits = llvm::to_vector(getInitArgs());
+ inits.append(newInitOperands.begin(), newInitOperands.end());
+ scf::ForOp newLoop = rewriter.create<scf::ForOp>(
+ getLoc(), getLowerBound(), getUpperBound(), getStep(), inits,
+ [](OpBuilder &, Location, Value, ValueRange) {});
+
+ // Generate the new yield values and append them to the scf.yield operation.
+ auto yieldOp = cast<scf::YieldOp>(getBody()->getTerminator());
+ ArrayRef<BlockArgument> newIterArgs =
+ newLoop.getBody()->getArguments().take_back(newInitOperands.size());
+ {
+ OpBuilder::InsertionGuard g(rewriter);
+ rewriter.setInsertionPoint(yieldOp);
+ SmallVector<Value> newYieldedValues =
+ newYieldValuesFn(rewriter, getLoc(), newIterArgs);
+ assert(newInitOperands.size() == newYieldedValues.size() &&
+ "expected as many new yield values as new iter operands");
+ rewriter.updateRootInPlace(yieldOp, [&]() {
+ yieldOp.getResultsMutable().append(newYieldedValues);
+ });
+ }
+
+ // Move the loop body to the new op.
+ rewriter.mergeBlocks(getBody(), newLoop.getBody(),
+ newLoop.getBody()->getArguments().take_front(
+ getBody()->getNumArguments()));
+
+ if (replaceInitOperandUsesInLoop) {
+ // Replace all uses of `newInitOperands` with the corresponding basic block
+ // arguments.
+ for (auto it : llvm::zip(newInitOperands, newIterArgs)) {
+ rewriter.replaceUsesWithIf(std::get<0>(it), std::get<1>(it),
+ [&](OpOperand &use) {
+ Operation *user = use.getOwner();
+ return newLoop->isProperAncestor(user);
+ });
+ }
+ }
+
+ // Replace the old loop.
+ rewriter.replaceOp(getOperation(),
+ newLoop->getResults().take_front(getNumResults()));
+ return cast<LoopLikeOpInterface>(newLoop.getOperation());
+}
+
ForOp mlir::scf::getForInductionVarOwner(Value val) {
auto ivArg = llvm::dyn_cast<BlockArgument>(val);
if (!ivArg)
diff --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
index ab59eac2ac4d6f8..e2fb417f5f07a6d 100644
--- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
@@ -174,7 +174,7 @@ yieldTiledValues(RewriterBase &rewriter, ValueRange initValues,
ArrayRef<SmallVector<OpFoldResult>> tileOffsetsList,
ArrayRef<SmallVector<OpFoldResult>> tileSizesList,
MutableArrayRef<scf::ForOp> loops) {
- NewYieldValueFn yieldValueFn =
+ NewYieldValuesFn yieldValueFn =
[&](OpBuilder &b, Location loc,
ArrayRef<BlockArgument> newBBArgs) -> SmallVector<Value> {
SmallVector<Value> inserts;
@@ -196,7 +196,6 @@ yieldTiledValues(RewriterBase &rewriter, ValueRange initValues,
replaceLoopNestWithNewYields(rewriter, loops, initValues, yieldValueFn,
/*replaceIterOperandsUsesInLoop =*/false);
for (const auto &loop : llvm::enumerate(loops)) {
- rewriter.eraseOp(loop.value());
loops[loop.index()] = newLoops[loop.index()];
}
return llvm::to_vector(llvm::map_range(
diff --git a/mlir/lib/Dialect/SCF/Utils/Utils.cpp b/mlir/lib/Dialect/SCF/Utils/Utils.cpp
index 411503700eb01c3..5360c493f8f8d71 100644
--- a/mlir/lib/Dialect/SCF/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/SCF/Utils/Utils.cpp
@@ -38,77 +38,9 @@ struct LoopParams {
};
} // namespace
-scf::ForOp
-mlir::replaceLoopWithNewYields(OpBuilder &builder, scf::ForOp loop,
- ValueRange newIterOperands,
- const NewYieldValueFn &newYieldValuesFn,
- bool replaceIterOperandsUsesInLoop) {
- // Create a new loop before the existing one, with the extra operands.
- OpBuilder::InsertionGuard g(builder);
- builder.setInsertionPoint(loop);
- auto operands = llvm::to_vector(loop.getInitArgs());
- llvm::append_range(operands, newIterOperands);
- scf::ForOp newLoop = builder.create<scf::ForOp>(
- loop.getLoc(), loop.getLowerBound(), loop.getUpperBound(), loop.getStep(),
- operands, [](OpBuilder &, Location, Value, ValueRange) {});
-
- Block *loopBody = loop.getBody();
- Block *newLoopBody = newLoop.getBody();
-
- // Move the body of the original loop to the new loop.
- newLoopBody->getOperations().splice(newLoopBody->end(),
- loopBody->getOperations());
-
- // Generate the new yield values to use by using the callback and append the
- // yield values to the scf.yield operation.
- auto yield = cast<scf::YieldOp>(newLoopBody->getTerminator());
- ArrayRef<BlockArgument> newBBArgs =
- newLoopBody->getArguments().take_back(newIterOperands.size());
- {
- OpBuilder::InsertionGuard g(builder);
- builder.setInsertionPoint(yield);
- SmallVector<Value> newYieldedValues =
- newYieldValuesFn(builder, loop.getLoc(), newBBArgs);
- assert(newIterOperands.size() == newYieldedValues.size() &&
- "expected as many new yield values as new iter operands");
- yield.getResultsMutable().append(newYieldedValues);
- }
-
- // Remap the BlockArguments from the original loop to the new loop
- // BlockArguments.
- ArrayRef<BlockArgument> bbArgs = loopBody->getArguments();
- for (auto it :
- llvm::zip(bbArgs, newLoopBody->getArguments().take_front(bbArgs.size())))
- std::get<0>(it).replaceAllUsesWith(std::get<1>(it));
-
- if (replaceIterOperandsUsesInLoop) {
- // Replace all uses of `newIterOperands` with the corresponding basic block
- // arguments.
- for (auto it : llvm::zip(newIterOperands, newBBArgs)) {
- std::get<0>(it).replaceUsesWithIf(std::get<1>(it), [&](OpOperand &use) {
- Operation *user = use.getOwner();
- return newLoop->isProperAncestor(user);
- });
- }
- }
-
- // Replace all uses of the original loop with corresponding values from the
- // new loop.
- loop.replaceAllUsesWith(
- newLoop.getResults().take_front(loop.getNumResults()));
-
- // Add a fake yield to the original loop body that just returns the
- // BlockArguments corresponding to the iter_args. This makes it a no-op loop.
- // The loop is dead. The caller is expected to erase it.
- builder.setInsertionPointToEnd(loopBody);
- builder.create<scf::YieldOp>(loop->getLoc(), loop.getRegionIterArgs());
-
- return newLoop;
-}
-
SmallVector<scf::ForOp> mlir::replaceLoopNestWithNewYields(
- OpBuilder &builder, ArrayRef<scf::ForOp> loopNest,
- ValueRange newIterOperands, const NewYieldValueFn &newYieldValueFn,
+ RewriterBase &rewriter, MutableArrayRef<scf::ForOp> loopNest,
+ ValueRange newIterOperands, const NewYieldValuesFn &newYieldValuesFn,
bool replaceIterOperandsUsesInLoop) {
if (loopNest.empty())
return {};
@@ -146,31 +78,32 @@ SmallVector<scf::ForOp> mlir::replaceLoopNestWithNewYields(
// }
// ```
//
- // The inner most loop is handled using the `replaceLoopWithNewYields`
+ // The inner most loop is handled using the `replaceWithAdditionalYields`
// that works on a single loop.
if (loopNest.size() == 1) {
- auto innerMostLoop = replaceLoopWithNewYields(
- builder, loopNest.back(), newIterOperands, newYieldValueFn,
- replaceIterOperandsUsesInLoop);
+ auto innerMostLoop =
+ cast<scf::ForOp>(*loopNest.back().replaceWithAdditionalYields(
+ rewriter, newIterOperands, replaceIterOperandsUsesInLoop,
+ newYieldValuesFn));
return {innerMostLoop};
}
// The outer loops are modified by calling this method recursively
// - The return value of the inner loop is the value yielded by this loop.
// - The region iter args of this loop are the init_args for the inner loop.
SmallVector<scf::ForOp> newLoopNest;
- NewYieldValueFn fn =
+ NewYieldValuesFn fn =
[&](OpBuilder &innerBuilder, Location loc,
ArrayRef<BlockArgument> innerNewBBArgs) -> SmallVector<Value> {
- newLoopNest = replaceLoopNestWithNewYields(builder, loopNest.drop_front(),
- innerNewBBArgs, newYieldValueFn,
+ newLoopNest = replaceLoopNestWithNewYields(rewriter, loopNest.drop_front(),
+ innerNewBBArgs, newYieldValuesFn,
replaceIterOperandsUsesInLoop);
return llvm::to_vector(llvm::map_range(
newLoopNest.front().getResults().take_back(innerNewBBArgs.size()),
[](OpResult r) -> Value { return r; }));
};
scf::ForOp outerMostLoop =
- replaceLoopWithNewYields(builder, loopNest.front(), newIterOperands, fn,
- replaceIterOperandsUsesInLoop);
+ cast<scf::ForOp>(*loopNest.front().replaceWithAdditionalYields(
+ rewriter, newIterOperands, replaceIterOperandsUsesInLoop, fn));
newLoopNest.insert(newLoopNest.begin(), outerMostLoop);
return newLoopNest;
}
diff --git a/mlir/test/Transforms/scf-replace-with-new-yields.mlir b/mlir/test/Transforms/scf-replace-with-new-yields.mlir
index 802f86e6806b318..d7af15f52e72a9b 100644
--- a/mlir/test/Transforms/scf-replace-with-new-yields.mlir
+++ b/mlir/test/Transforms/scf-replace-with-new-yields.mlir
@@ -15,7 +15,4 @@ func.func @doubleup(%lb: index, %ub: index, %step: index, %extra_arg: f32) -> f3
// CHECK: %[[DOUBLE:.+]] = arith.addf %[[INIT1]], %[[INIT1]]
// CHECK: %[[DOUBLE2:.+]] = arith.addf %[[DOUBLE]], %[[DOUBLE]]
// CHECK: scf.yield %[[DOUBLE]], %[[DOUBLE2]]
-// CHECK: %[[OLDLOOP:.+]] = scf.for
-// CHECK-SAME: iter_args(%[[INIT:.+]] = %[[ARG]])
-// CHECK: scf.yield %[[INIT]]
// CHECK: return %[[NEWLOOP]]#0
diff --git a/mlir/test/lib/Dialect/SCF/TestSCFUtils.cpp b/mlir/test/lib/Dialect/SCF/TestSCFUtils.cpp
index 455c9234b8c93de..1d40615305c02c4 100644
--- a/mlir/test/lib/Dialect/SCF/TestSCFUtils.cpp
+++ b/mlir/test/lib/Dialect/SCF/TestSCFUtils.cpp
@@ -50,19 +50,23 @@ struct TestSCFForUtilsPass
auto newInitValues = forOp.getInitArgs();
if (newInitValues.empty())
return;
- NewYieldValueFn fn = [&](OpBuilder &b, Location loc,
- ArrayRef<BlockArgument> newBBArgs) {
- Block *block = newBBArgs.front().getOwner();
+ auto yieldOp = cast<scf::YieldOp>(forOp.getBody()->getTerminator());
+ SmallVector<Value> oldYieldValues(yieldOp.getResults().begin(),
+ yieldOp.getResults().end());
+ NewYieldValuesFn fn = [&](OpBuilder &b, Location loc,
+ ArrayRef<BlockArgument> newBBArgs) {
SmallVector<Value> newYieldValues;
- for (auto yieldVal :
- cast<scf::YieldOp>(block->getTerminator()).getResults()) {
+ for (auto yieldVal : oldYieldValues) {
newYieldValues.push_back(
b.create<arith::AddFOp>(loc, yieldVal, yieldVal));
}
return newYieldValues;
};
- OpBuilder b(forOp);
- replaceLoopWithNewYields(b, forOp, newInitValues, fn);
+ IRRewriter rewriter(forOp.getContext());
+ if (failed(forOp.replaceWithAdditionalYields(
+ rewriter, newInitValues, /*replaceInitOperandUsesInLoop=*/true,
+ fn)))
+ signalPassFailure();
});
}
}
>From cbf54441fe26cafa050ee7f4f5c2bc68de0dde92 Mon Sep 17 00:00:00 2001
From: Matthias Springer <me at m-sp.org>
Date: Mon, 25 Sep 2023 12:29:13 +0200
Subject: [PATCH 2/2] [mlir][Interfaces] `LoopLikeOpInterface`: Add helper to
get yielded values
Add a new interface method that returns the yielded values.
Also add a verifier that checks the number and types of inits/iter_args/yielded values. Most of the checked invariants (but not all of them) are already covered by the `RegionBranchOpInterface`, but the `LoopLikeOpInterface` now provides (additional) error messages that are easier to read.
---
.../include/flang/Optimizer/Dialect/FIROps.td | 4 +-
flang/lib/Lower/Bridge.cpp | 6 +--
flang/lib/Lower/ConvertExpr.cpp | 4 +-
flang/lib/Optimizer/Dialect/FIROps.cpp | 8 ++--
.../Transforms/SimplifyIntrinsics.cpp | 6 +--
.../mlir/Dialect/Affine/IR/AffineOps.td | 3 +-
mlir/include/mlir/Dialect/SCF/IR/SCFOps.td | 11 +++++-
.../mlir/Interfaces/LoopLikeInterface.h | 5 +++
.../mlir/Interfaces/LoopLikeInterface.td | 30 +++++++++++++++
mlir/lib/Dialect/Affine/IR/AffineOps.cpp | 4 ++
.../Linalg/Transforms/HoistPadding.cpp | 5 +--
mlir/lib/Dialect/SCF/IR/SCF.cpp | 27 ++++++++------
.../BufferizableOpInterfaceImpl.cpp | 3 +-
.../SCF/Transforms/LoopCanonicalization.cpp | 5 +--
mlir/lib/Interfaces/LoopLikeInterface.cpp | 37 +++++++++++++++++++
mlir/test/Dialect/SCF/invalid.mlir | 30 ++++++++++++++-
mlir/test/lib/Dialect/SCF/TestSCFUtils.cpp | 5 +--
17 files changed, 151 insertions(+), 42 deletions(-)
diff --git a/flang/include/flang/Optimizer/Dialect/FIROps.td b/flang/include/flang/Optimizer/Dialect/FIROps.td
index a57add9f731979d..483c679eb623ce6 100644
--- a/flang/include/flang/Optimizer/Dialect/FIROps.td
+++ b/flang/include/flang/Optimizer/Dialect/FIROps.td
@@ -2106,7 +2106,7 @@ def fir_DoLoopOp : region_Op<"do_loop",
mlir::OpBuilder getBodyBuilder() {
return mlir::OpBuilder(getBody(), std::prev(getBody()->end()));
}
- mlir::Block::BlockArgListType getRegionIterArgs() {
+ mlir::Block::BlockArgListType getIterArgs() {
return getBody()->getArguments().drop_front();
}
mlir::Operation::operand_range getIterOperands() {
@@ -2257,7 +2257,7 @@ def fir_IterWhileOp : region_Op<"iterate_while",
mlir::OpBuilder getBodyBuilder() {
return mlir::OpBuilder(getBody(), std::prev(getBody()->end()));
}
- mlir::Block::BlockArgListType getRegionIterArgs() {
+ mlir::Block::BlockArgListType getIterArgs() {
return getBody()->getArguments().drop_front();
}
mlir::Operation::operand_range getIterOperands() {
diff --git a/flang/lib/Lower/Bridge.cpp b/flang/lib/Lower/Bridge.cpp
index 2fb3c2c3818d42d..61db6e0b1bfe5f9 100644
--- a/flang/lib/Lower/Bridge.cpp
+++ b/flang/lib/Lower/Bridge.cpp
@@ -1789,7 +1789,7 @@ class FirConverter : public Fortran::lower::AbstractConverter {
/*finalCountValue=*/true,
builder->createConvert(loc, loopVarType, lowerValue));
builder->setInsertionPointToStart(info.doLoop.getBody());
- loopValue = info.doLoop.getRegionIterArgs()[0];
+ loopValue = info.doLoop.getIterArgs()[0];
}
// Update the loop variable value in case it has non-index references.
builder->create<fir::StoreOp>(loc, loopValue, info.loopVariable);
@@ -2105,9 +2105,9 @@ class FirConverter : public Fortran::lower::AbstractConverter {
auto lp = builder->create<fir::DoLoopOp>(
loc, lb, ub, by, /*unordered=*/true,
/*finalCount=*/false, explicitIterSpace.getInnerArgs());
- if ((!loops.empty() || !outermost) && !lp.getRegionIterArgs().empty())
+ if ((!loops.empty() || !outermost) && !lp.getIterArgs().empty())
builder->create<fir::ResultOp>(loc, lp.getResults());
- explicitIterSpace.setInnerArgs(lp.getRegionIterArgs());
+ explicitIterSpace.setInnerArgs(lp.getIterArgs());
builder->setInsertionPointToStart(lp.getBody());
forceControlVariableBinding(ctrlVar, lp.getInductionVar());
loops.push_back(lp);
diff --git a/flang/lib/Lower/ConvertExpr.cpp b/flang/lib/Lower/ConvertExpr.cpp
index 8788e82b59a8df0..9871e4e0dedd3be 100644
--- a/flang/lib/Lower/ConvertExpr.cpp
+++ b/flang/lib/Lower/ConvertExpr.cpp
@@ -4342,7 +4342,7 @@ class ArrayExprLowering {
loop = builder.create<fir::DoLoopOp>(
loc, zero, i.value(), one, isUnordered(),
/*finalCount=*/false, mlir::ValueRange{innerArg});
- innerArg = loop.getRegionIterArgs().front();
+ innerArg = loop.getIterArgs().front();
if (explicitSpaceIsActive())
explicitSpace->setInnerArg(0, innerArg);
} else {
@@ -6328,7 +6328,7 @@ class ArrayExprLowering {
auto insPt = builder.saveInsertionPoint();
builder.setInsertionPointToStart(loop.getBody());
// Thread mem inside the loop via loop argument.
- mem = loop.getRegionIterArgs()[0];
+ mem = loop.getIterArgs()[0];
mlir::Type eleRefTy = builder.getRefType(eleTy);
diff --git a/flang/lib/Optimizer/Dialect/FIROps.cpp b/flang/lib/Optimizer/Dialect/FIROps.cpp
index 962b87acd5a8050..1764c7812a78598 100644
--- a/flang/lib/Optimizer/Dialect/FIROps.cpp
+++ b/flang/lib/Optimizer/Dialect/FIROps.cpp
@@ -1905,7 +1905,7 @@ mlir::LogicalResult fir::IterWhileOp::verify() {
return emitOpError(
"mismatch in number of basic block args and defined values");
auto iterOperands = getIterOperands();
- auto iterArgs = getRegionIterArgs();
+ auto iterArgs = getIterArgs();
auto opResults = getFinalValue() ? getResults().drop_front() : getResults();
unsigned i = 0u;
for (auto e : llvm::zip(iterOperands, iterArgs, opResults)) {
@@ -1925,7 +1925,7 @@ void fir::IterWhileOp::print(mlir::OpAsmPrinter &p) {
p << " (" << getInductionVar() << " = " << getLowerBound() << " to "
<< getUpperBound() << " step " << getStep() << ") and (";
assert(hasIterOperands());
- auto regionArgs = getRegionIterArgs();
+ auto regionArgs = getIterArgs();
auto operands = getIterOperands();
p << regionArgs.front() << " = " << *operands.begin() << ")";
if (regionArgs.size() > 1) {
@@ -2194,7 +2194,7 @@ mlir::LogicalResult fir::DoLoopOp::verify() {
return emitOpError(
"mismatch in number of basic block args and defined values");
auto iterOperands = getIterOperands();
- auto iterArgs = getRegionIterArgs();
+ auto iterArgs = getIterArgs();
auto opResults = getFinalValue() ? getResults().drop_front() : getResults();
unsigned i = 0u;
for (auto e : llvm::zip(iterOperands, iterArgs, opResults)) {
@@ -2218,7 +2218,7 @@ void fir::DoLoopOp::print(mlir::OpAsmPrinter &p) {
p << " unordered";
if (hasIterOperands()) {
p << " iter_args(";
- auto regionArgs = getRegionIterArgs();
+ auto regionArgs = getIterArgs();
auto operands = getIterOperands();
llvm::interleaveComma(llvm::zip(regionArgs, operands), p, [&](auto it) {
p << std::get<0>(it) << " = " << std::get<1>(it);
diff --git a/flang/lib/Optimizer/Transforms/SimplifyIntrinsics.cpp b/flang/lib/Optimizer/Transforms/SimplifyIntrinsics.cpp
index 3eddb9e61ae3b3d..7e6f668a012a0a0 100644
--- a/flang/lib/Optimizer/Transforms/SimplifyIntrinsics.cpp
+++ b/flang/lib/Optimizer/Transforms/SimplifyIntrinsics.cpp
@@ -316,7 +316,7 @@ genReductionLoop(fir::FirOpBuilder &builder, mlir::func::FuncOp &funcOp,
auto loop = builder.create<OP>(loc, zeroIdx, loopCount, step,
unorderedOrInitialLoopCond,
/*finalCountValue=*/false, init);
- init = loop.getRegionIterArgs()[resultIndex];
+ init = loop.getIterArgs()[resultIndex];
indices.push_back(loop.getInductionVar());
// Set insertion point to the loop body so that the next loop
// is inserted inside the current one.
@@ -422,7 +422,7 @@ genMinlocReductionLoop(fir::FirOpBuilder &builder, mlir::func::FuncOp &funcOp,
auto loop =
builder.create<fir::DoLoopOp>(loc, zeroIdx, loopCount, step, false,
/*finalCountValue=*/false, init);
- init = loop.getRegionIterArgs()[0];
+ init = loop.getIterArgs()[0];
indices.push_back(loop.getInductionVar());
// Set insertion point to the loop body so that the next loop
// is inserted inside the current one.
@@ -952,7 +952,7 @@ static void genRuntimeDotBody(fir::FirOpBuilder &builder,
auto loop = builder.create<fir::DoLoopOp>(loc, zeroIdx, loopCount, step,
/*unordered=*/false,
/*finalCountValue=*/false, zero);
- mlir::Value sumVal = loop.getRegionIterArgs()[0];
+ mlir::Value sumVal = loop.getIterArgs()[0];
// Begin loop code
mlir::OpBuilder::InsertPoint loopEndPt = builder.saveInsertionPoint();
diff --git a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
index d8ef0506d0822d7..6eb0a4468fd91b5 100644
--- a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
+++ b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
@@ -120,7 +120,8 @@ def AffineForOp : Affine_Op<"for",
[AutomaticAllocationScope, ImplicitAffineTerminator, ConditionallySpeculatable,
RecursiveMemoryEffects, DeclareOpInterfaceMethods<LoopLikeOpInterface,
["getSingleInductionVar", "getSingleLowerBound", "getSingleStep",
- "getSingleUpperBound", "replaceWithAdditionalYields"]>,
+ "getSingleUpperBound", "getYieldedValues",
+ "replaceWithAdditionalYields"]>,
DeclareOpInterfaceMethods<RegionBranchOpInterface,
["getEntrySuccessorOperands"]>]> {
let summary = "for operation";
diff --git a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
index e62015b84888a18..c63703327f5bb34 100644
--- a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
+++ b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
@@ -121,8 +121,8 @@ def ExecuteRegionOp : SCF_Op<"execute_region", [
def ForOp : SCF_Op<"for",
[AutomaticAllocationScope, DeclareOpInterfaceMethods<LoopLikeOpInterface,
["getInits", "getSingleInductionVar", "getSingleLowerBound",
- "getSingleStep", "getSingleUpperBound", "promoteIfSingleIteration",
- "replaceWithAdditionalYields"]>,
+ "getSingleStep", "getSingleUpperBound", "getYieldedValues",
+ "promoteIfSingleIteration", "replaceWithAdditionalYields"]>,
AllTypesMatch<["lowerBound", "upperBound", "step"]>,
ConditionallySpeculatable,
DeclareOpInterfaceMethods<RegionBranchOpInterface,
@@ -242,9 +242,11 @@ def ForOp : SCF_Op<"for",
function_ref<void(OpBuilder &, Location, Value, ValueRange)>;
Value getInductionVar() { return getBody()->getArgument(0); }
+
Block::BlockArgListType getRegionIterArgs() {
return getBody()->getArguments().drop_front(getNumInductionVars());
}
+
/// Return the `index`-th region iteration argument.
BlockArgument getRegionIterArg(unsigned index) {
assert(index < getNumRegionIterArgs() &&
@@ -1081,6 +1083,11 @@ def WhileOp : SCF_Op<"while",
ConditionOp getConditionOp();
YieldOp getYieldOp();
+
+    /// Return the values that are yielded from the "after" region (by the
+    /// YieldOp returned by `getYieldOp()`).
+    ValueRange getYieldedValues();
+
Block::BlockArgListType getBeforeArguments();
Block::BlockArgListType getAfterArguments();
Block *getBeforeBody() { return &getBefore().front(); }
diff --git a/mlir/include/mlir/Interfaces/LoopLikeInterface.h b/mlir/include/mlir/Interfaces/LoopLikeInterface.h
index 0eebb984e5897ae..7c7d378d0590ab1 100644
--- a/mlir/include/mlir/Interfaces/LoopLikeInterface.h
+++ b/mlir/include/mlir/Interfaces/LoopLikeInterface.h
@@ -24,6 +24,11 @@ class RewriterBase;
/// arguments in `newBbArgs`.
using NewYieldValuesFn = std::function<SmallVector<Value>(
OpBuilder &b, Location loc, ArrayRef<BlockArgument> newBbArgs)>;
+
+namespace detail {
+/// Verify invariants of the LoopLikeOpInterface.
+LogicalResult verifyLoopLikeOpInterface(Operation *op);
+} // namespace detail
} // namespace mlir
/// Include the generated interface declarations.
diff --git a/mlir/include/mlir/Interfaces/LoopLikeInterface.td b/mlir/include/mlir/Interfaces/LoopLikeInterface.td
index ded0a29292ff6f0..5e3c0db965156c4 100644
--- a/mlir/include/mlir/Interfaces/LoopLikeInterface.td
+++ b/mlir/include/mlir/Interfaces/LoopLikeInterface.td
@@ -20,6 +20,19 @@ def LoopLikeOpInterface : OpInterface<"LoopLikeOpInterface"> {
Contains helper functions to query properties and perform transformations
of a loop. Operations that implement this interface will be considered by
loop-invariant code motion.
+
+ Loop-carried variables can be exposed through this interface. There are
+ 3 components to a loop-carried variable.
+ - The "region iter_arg" is the block argument of the entry block that
+ represents the loop-carried variable in each iteration.
+ - The "init value" is an operand of the loop op that serves as the initial
+ region iter_arg value for the first iteration (if any).
+ - The "yielded" value is the value that is forwarded from one iteration to
+ serve as the region iter_arg of the next iteration.
+
+  If one of the respective interface methods is implemented, the other two
+  must be implemented as well. The interface verifier ensures that the number
+  and types of the region iter_args, init values and yielded values match.
}];
let cppNamespace = "::mlir";
@@ -141,6 +154,17 @@ def LoopLikeOpInterface : OpInterface<"LoopLikeOpInterface"> {
return ::mlir::Block::BlockArgListType();
}]
>,
+ InterfaceMethod<[{
+    Return the range of values that are yielded to the next iteration.
+ }],
+ /*retTy=*/"::mlir::ValueRange",
+ /*methodName=*/"getYieldedValues",
+ /*args=*/(ins),
+ /*methodBody=*/"",
+ /*defaultImplementation=*/[{
+ return ::mlir::ValueRange();
+ }]
+ >,
InterfaceMethod<[{
Append the specified additional "init" operands: replace this loop with
a new loop that has the additional init operands. The loop body of
@@ -192,6 +216,12 @@ def LoopLikeOpInterface : OpInterface<"LoopLikeOpInterface"> {
});
}
}];
+
+ let verifyWithRegions = 1;
+
+ let verify = [{
+ return detail::verifyLoopLikeOpInterface($_op);
+ }];
}
#endif // MLIR_INTERFACES_LOOPLIKEINTERFACE
diff --git a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
index 6c060c90e24af82..4109167de44e860 100644
--- a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
+++ b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
@@ -2193,6 +2193,10 @@ unsigned AffineForOp::getNumIterOperands() {
return getNumOperands() - lbMap.getNumInputs() - ubMap.getNumInputs();
}
+ValueRange AffineForOp::getYieldedValues() {
+ return cast<AffineYieldOp>(getBody()->getTerminator()).getOperands();
+}
+
void AffineForOp::print(OpAsmPrinter &p) {
p << ' ';
p.printRegionArgument(getBody()->getArgument(0), /*argAttrs=*/{},
diff --git a/mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp b/mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp
index 72bd2b409f5d52b..8fef99bb375095a 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp
@@ -811,8 +811,7 @@ padThroughLoopIterArg(RewriterBase &rewriter, Value paddedValueBeforeHoisting,
rewriter.setInsertionPointAfter(hoistedPackedTensor.getDefiningOp());
unsigned iterArgNumber = forOp.getResultForOpOperand(*pUse).getResultNumber();
- auto yieldOp = cast<scf::YieldOp>(forOp.getBody(0)->getTerminator());
- auto yieldingExtractSliceOp = yieldOp->getOperand(iterArgNumber)
+ auto yieldingExtractSliceOp = forOp.getYieldedValues()[iterArgNumber]
.getDefiningOp<tensor::ExtractSliceOp>();
if (!yieldingExtractSliceOp)
return tensor::ExtractSliceOp();
@@ -826,7 +825,7 @@ padThroughLoopIterArg(RewriterBase &rewriter, Value paddedValueBeforeHoisting,
SmallVector<Value> initArgs = forOp.getInitArgs();
initArgs[iterArgNumber] = hoistedPackedTensor;
- SmallVector<Value> yieldOperands = yieldOp.getOperands();
+ SmallVector<Value> yieldOperands = llvm::to_vector(forOp.getYieldedValues());
yieldOperands[iterArgNumber] = yieldingExtractSliceOp.getSource();
int64_t numOriginalForOpResults = initArgs.size();
diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
index 8d8481421e18d57..508227d6e7ce4b0 100644
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -400,7 +400,7 @@ LogicalResult ForOp::promoteIfSingleIteration(RewriterBase &rewriter) {
// Replace all results with the yielded values.
auto yieldOp = cast<scf::YieldOp>(getBody()->getTerminator());
- rewriter.replaceAllUsesWith(getResults(), yieldOp.getOperands());
+ rewriter.replaceAllUsesWith(getResults(), getYieldedValues());
// Replace block arguments with lower bound (replacement for IV) and
// iter_args.
@@ -772,27 +772,26 @@ struct ForOpIterArgsFolder : public OpRewritePattern<scf::ForOp> {
LogicalResult matchAndRewrite(scf::ForOp forOp,
PatternRewriter &rewriter) const final {
bool canonicalize = false;
- Block &block = forOp.getRegion().front();
- auto yieldOp = cast<scf::YieldOp>(block.getTerminator());
// An internal flat vector of block transfer
// arguments `newBlockTransferArgs` keeps the 1-1 mapping of original to
// transformed block argument mappings. This plays the role of a
// IRMapping for the particular use case of calling into
// `inlineBlockBefore`.
+ int64_t numResults = forOp.getNumResults();
SmallVector<bool, 4> keepMask;
- keepMask.reserve(yieldOp.getNumOperands());
+ keepMask.reserve(numResults);
SmallVector<Value, 4> newBlockTransferArgs, newIterArgs, newYieldValues,
newResultValues;
- newBlockTransferArgs.reserve(1 + forOp.getInitArgs().size());
+ newBlockTransferArgs.reserve(1 + numResults);
newBlockTransferArgs.push_back(Value()); // iv placeholder with null value
newIterArgs.reserve(forOp.getInitArgs().size());
- newYieldValues.reserve(yieldOp.getNumOperands());
- newResultValues.reserve(forOp.getNumResults());
+ newYieldValues.reserve(numResults);
+ newResultValues.reserve(numResults);
for (auto it : llvm::zip(forOp.getInitArgs(), // iter from outside
forOp.getRegionIterArgs(), // iter inside region
forOp.getResults(), // op results
- yieldOp.getOperands() // iter yield
+ forOp.getYieldedValues() // iter yield
)) {
// Forwarded is `true` when:
// 1) The region `iter` argument is yielded.
@@ -946,12 +945,10 @@ struct SimplifyTrivialLoops : public OpRewritePattern<ForOp> {
return failure();
// If the loop is empty, iterates at least once, and only returns values
// defined outside of the loop, remove it and replace it with yield values.
- auto yieldOp = cast<scf::YieldOp>(block.getTerminator());
- auto yieldOperands = yieldOp.getOperands();
- if (llvm::any_of(yieldOperands,
+ if (llvm::any_of(op.getYieldedValues(),
[&](Value v) { return !op.isDefinedOutsideOfLoop(v); }))
return failure();
- rewriter.replaceOp(op, yieldOperands);
+ rewriter.replaceOp(op, op.getYieldedValues());
return success();
}
};
@@ -1224,6 +1221,10 @@ std::optional<APInt> ForOp::getConstantStep() {
return {};
}
+ValueRange ForOp::getYieldedValues() {
+ return cast<scf::YieldOp>(getBody()->getTerminator()).getResults();
+}
+
Speculation::Speculatability ForOp::getSpeculatability() {
// `scf.for (I = Start; I < End; I += 1)` terminates for all values of Start
// and End.
@@ -3205,6 +3206,8 @@ YieldOp WhileOp::getYieldOp() {
return cast<YieldOp>(getAfterBody()->getTerminator());
}
+ValueRange WhileOp::getYieldedValues() { return getYieldOp().getResults(); }
+
Block::BlockArgListType WhileOp::getBeforeArguments() {
return getBeforeBody()->getArguments();
}
diff --git a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
index bcbc693a9742ccc..22c893457c6f59e 100644
--- a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -489,9 +489,8 @@ struct ForOpInterface
auto forOp = cast<scf::ForOp>(op);
OpOperand &forOperand = forOp.getOpOperandForResult(opResult);
auto bbArg = forOp.getRegionIterArgForOpOperand(forOperand);
- auto yieldOp = cast<scf::YieldOp>(forOp.getBody()->getTerminator());
bool equivalentYield = state.areEquivalentBufferizedValues(
- bbArg, yieldOp->getOperand(opResult.getResultNumber()));
+ bbArg, forOp.getYieldedValues()[opResult.getResultNumber()]);
return equivalentYield ? BufferRelation::Equivalent
: BufferRelation::Unknown;
}
diff --git a/mlir/lib/Dialect/SCF/Transforms/LoopCanonicalization.cpp b/mlir/lib/Dialect/SCF/Transforms/LoopCanonicalization.cpp
index 0cd19fbefa8ef98..43e79d309c66784 100644
--- a/mlir/lib/Dialect/SCF/Transforms/LoopCanonicalization.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/LoopCanonicalization.cpp
@@ -36,10 +36,9 @@ using namespace mlir::scf;
/// type of the corresponding basic block argument of the loop.
/// Note: This function handles only simple cases. Expand as needed.
static bool isShapePreserving(ForOp forOp, int64_t arg) {
- auto yieldOp = cast<YieldOp>(forOp.getBody()->getTerminator());
- assert(arg < static_cast<int64_t>(yieldOp.getResults().size()) &&
+ assert(arg < static_cast<int64_t>(forOp.getNumResults()) &&
"arg is out of bounds");
- Value value = yieldOp.getResults()[arg];
+ Value value = forOp.getYieldedValues()[arg];
while (value) {
if (value == forOp.getRegionIterArgs()[arg])
return true;
diff --git a/mlir/lib/Interfaces/LoopLikeInterface.cpp b/mlir/lib/Interfaces/LoopLikeInterface.cpp
index 781a21bb3ecd3a0..15a816f4e448839 100644
--- a/mlir/lib/Interfaces/LoopLikeInterface.cpp
+++ b/mlir/lib/Interfaces/LoopLikeInterface.cpp
@@ -52,3 +52,40 @@ bool LoopLikeOpInterface::blockIsInLoop(Block *block) {
}
return false;
}
+
+LogicalResult detail::verifyLoopLikeOpInterface(Operation *op) {
+  // Note: These invariants are also verified by the RegionBranchOpInterface,
+  // but the LoopLikeOpInterface provides better error messages.
+  auto loopLikeOp = cast<LoopLikeOpInterface>(op);
+
+  // Verify number of inits/iter_args/yielded values.
+  if (loopLikeOp.getInits().size() != loopLikeOp.getRegionIterArgs().size())
+    return op->emitOpError("different number of inits and region iter_args: ")
+           << loopLikeOp.getInits().size()
+           << " != " << loopLikeOp.getRegionIterArgs().size();
+  if (loopLikeOp.getRegionIterArgs().size() !=
+      loopLikeOp.getYieldedValues().size())
+    return op->emitOpError(
+               "different number of region iter_args and yielded values: ")
+           << loopLikeOp.getRegionIterArgs().size()
+           << " != " << loopLikeOp.getYieldedValues().size();
+
+  // Verify types of inits/iter_args/yielded values. Return on the first
+  // mismatch so that verification actually fails (not just a diagnostic).
+  int64_t i = 0;
+  for (const auto &[init, iterArg, yielded] :
+       llvm::zip_equal(loopLikeOp.getInits(), loopLikeOp.getRegionIterArgs(),
+                       loopLikeOp.getYieldedValues())) {
+    if (init.getType() != iterArg.getType())
+      return op->emitOpError() << i << "-th init and " << i
+             << "-th region iter_arg have different type: " << init.getType()
+             << " != " << iterArg.getType();
+    if (iterArg.getType() != yielded.getType())
+      return op->emitOpError() << i << "-th region iter_arg and " << i
+             << "-th yielded value have different type: " << iterArg.getType()
+             << " != " << yielded.getType();
+    ++i;
+  }
+
+  return success();
+}
diff --git a/mlir/test/Dialect/SCF/invalid.mlir b/mlir/test/Dialect/SCF/invalid.mlir
index f6044ad10829227..1b2c3f563195c52 100644
--- a/mlir/test/Dialect/SCF/invalid.mlir
+++ b/mlir/test/Dialect/SCF/invalid.mlir
@@ -96,6 +96,32 @@ func.func @not_enough_loop_results(%arg0: index, %init: f32) {
// -----
+func.func @too_many_iter_args(%arg0: index, %init: f32) {
+  // expected-error @below{{different number of inits and region iter_args: 1 != 2}}
+  %x = "scf.for"(%arg0, %arg0, %arg0, %init) (
+    {
+    ^bb0(%i0 : index, %iter: f32, %iter2: f32):
+      scf.yield %iter, %iter : f32, f32
+    }
+  ) : (index, index, index, f32) -> (f32)
+  return
+}
+
+// -----
+
+func.func @too_few_yielded_values(%arg0: index, %init: f32) {
+  // expected-error @below{{different number of region iter_args and yielded values: 2 != 1}}
+  %x, %x2 = "scf.for"(%arg0, %arg0, %arg0, %init, %init) (
+    {
+    ^bb0(%i0 : index, %iter: f32, %iter2: f32):
+      scf.yield %iter : f32
+    }
+  ) : (index, index, index, f32, f32) -> (f32, f32)
+  return
+}
+
+// -----
+
func.func @loop_if_not_i1(%arg0: index) {
// expected-error at +1 {{operand #0 must be 1-bit signless integer}}
"scf.if"(%arg0) ({}, {}) : (index) -> ()
@@ -422,7 +448,7 @@ func.func @std_for_operands_mismatch_3(%arg0 : index, %arg1 : index, %arg2 : ind
func.func @std_for_operands_mismatch_4(%arg0 : index, %arg1 : index, %arg2 : index) {
%s0 = arith.constant 0.0 : f32
%t0 = arith.constant 1.0 : f32
- // expected-error @+1 {{along control flow edge from Region #0 to Region #0: source type #1 'i32' should match input type #1 'f32'}}
+  // expected-error @below {{1-th region iter_arg and 1-th yielded value have different type: 'f32' != 'i32'}}
%result1:2 = scf.for %i0 = %arg0 to %arg1 step %arg2
iter_args(%si = %s0, %ti = %t0) -> (f32, f32) {
%sn = arith.addf %si, %si : f32
@@ -432,7 +458,6 @@ func.func @std_for_operands_mismatch_4(%arg0 : index, %arg1 : index, %arg2 : ind
return
}
-
// -----
func.func @parallel_invalid_yield(
diff --git a/mlir/test/lib/Dialect/SCF/TestSCFUtils.cpp b/mlir/test/lib/Dialect/SCF/TestSCFUtils.cpp
index 1d40615305c02c4..565d07669792f1b 100644
--- a/mlir/test/lib/Dialect/SCF/TestSCFUtils.cpp
+++ b/mlir/test/lib/Dialect/SCF/TestSCFUtils.cpp
@@ -50,9 +50,8 @@ struct TestSCFForUtilsPass
auto newInitValues = forOp.getInitArgs();
if (newInitValues.empty())
return;
- auto yieldOp = cast<scf::YieldOp>(forOp.getBody()->getTerminator());
- SmallVector<Value> oldYieldValues(yieldOp.getResults().begin(),
- yieldOp.getResults().end());
+ SmallVector<Value> oldYieldValues =
+ llvm::to_vector(forOp.getYieldedValues());
NewYieldValuesFn fn = [&](OpBuilder &b, Location loc,
ArrayRef<BlockArgument> newBBArgs) {
SmallVector<Value> newYieldValues;
More information about the flang-commits
mailing list