[Mlir-commits] [mlir] [MLIR][SCF] Add canonicalization pattern to fold away iter args of scf.forall (PR #90189)
Abhishek Varma
llvmlistbot at llvm.org
Thu May 2 01:32:26 PDT 2024
https://github.com/Abhishek-Varma updated https://github.com/llvm/llvm-project/pull/90189
>From f71cf509c45f61f106b5ffd7974d17377b6e646f Mon Sep 17 00:00:00 2001
From: Abhishek Varma <abhvarma at amd.com>
Date: Thu, 25 Apr 2024 11:01:04 +0000
Subject: [PATCH 1/2] [MLIR][SCF] Add canonicalization pattern to fold away
iter args of scf.forall
-- This commit adds a canonicalization pattern to fold away iter args
of scf.forall if :-
a. The corresponding tied result has no use.
b. It is not being modified within the loop.
Signed-off-by: Abhishek Varma <avarma094 at gmail.com>
---
mlir/lib/Dialect/SCF/IR/CMakeLists.txt | 1 +
mlir/lib/Dialect/SCF/IR/SCF.cpp | 200 +++++++++++++++++++++++-
mlir/test/Dialect/SCF/canonicalize.mlir | 80 ++++++++++
3 files changed, 280 insertions(+), 1 deletion(-)
diff --git a/mlir/lib/Dialect/SCF/IR/CMakeLists.txt b/mlir/lib/Dialect/SCF/IR/CMakeLists.txt
index 423e1c3e1e042c..6e5d80078e8022 100644
--- a/mlir/lib/Dialect/SCF/IR/CMakeLists.txt
+++ b/mlir/lib/Dialect/SCF/IR/CMakeLists.txt
@@ -17,6 +17,7 @@ add_mlir_dialect_library(MLIRSCFDialect
MLIRIR
MLIRLoopLikeInterface
MLIRSideEffectInterfaces
+ MLIRSubsetOpInterface
MLIRTensorDialect
MLIRValueBoundsOpInterface
)
diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
index 7a1aafc9f1c2f9..355cfc8b3ee626 100644
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -20,6 +20,7 @@
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Interfaces/FunctionInterfaces.h"
+#include "mlir/Interfaces/SubsetOpInterface.h"
#include "mlir/Interfaces/ValueBoundsOpInterface.h"
#include "mlir/Support/MathExtras.h"
#include "mlir/Transforms/InliningUtils.h"
@@ -1509,6 +1510,203 @@ class ForallOpControlOperandsFolder : public OpRewritePattern<ForallOp> {
}
};
+/// The following canonicalization pattern folds the iter arguments of
+/// scf.forall op if :-
+/// 1. The corresponding result has zero uses.
+/// 2. The iter argument is NOT being modified within the loop body.
+///
+///
+/// Example of first case :-
+/// INPUT:
+/// %res:3 = scf.forall ... shared_outs(%arg0 = %a, %arg1 = %b, %arg2 = %c)
+/// {
+/// ...
+/// <SOME USE OF %arg0>
+/// <SOME USE OF %arg1>
+/// <SOME USE OF %arg2>
+/// ...
+/// scf.forall.in_parallel {
+/// <STORE OP WITH DESTINATION %arg1>
+/// <STORE OP WITH DESTINATION %arg0>
+/// <STORE OP WITH DESTINATION %arg2>
+/// }
+/// }
+/// return %res#1
+///
+/// OUTPUT:
+/// %res:3 = scf.forall ... shared_outs(%new_arg0 = %b)
+/// {
+/// ...
+/// <SOME USE OF %a>
+/// <SOME USE OF %new_arg0>
+/// <SOME USE OF %c>
+/// ...
+/// scf.forall.in_parallel {
+/// <STORE OP WITH DESTINATION %new_arg0>
+/// }
+/// }
+/// return %res
+///
+/// NOTE: 1. All uses of the folded shared_outs (iter argument) within the
+/// scf.forall are replaced by their corresponding operands.
+/// 2. The canonicalization assumes that there are no <STORE OP WITH
+/// DESTINATION *> ops within the body of the scf.forall except within
+/// scf.forall.in_parallel terminator.
+/// 3. The order of the <STORE OP WITH DESTINATION *> can be arbitrary
+/// within scf.forall.in_parallel - the code below takes care of this
+/// by traversing the uses of the corresponding iter arg.
+///
+/// Example of second case :-
+/// INPUT:
+/// %res:2 = scf.forall ... shared_outs(%arg0 = %a, %arg1 = %b)
+/// {
+/// ...
+/// <SOME USE OF %arg0>
+/// <SOME USE OF %arg1>
+/// ...
+/// scf.forall.in_parallel {
+/// <STORE OP WITH DESTINATION %arg1>
+/// }
+/// }
+/// return %res#0, %res#1
+///
+/// OUTPUT:
+/// %res = scf.forall ... shared_outs(%new_arg0 = %b)
+/// {
+/// ...
+/// <SOME USE OF %a>
+/// <SOME USE OF %new_arg0>
+/// ...
+/// scf.forall.in_parallel {
+/// <STORE OP WITH DESTINATION %new_arg0>
+/// }
+/// }
+/// return %a, %res
+struct ForallOpIterArgsFolder : public OpRewritePattern<ForallOp> {
+ using OpRewritePattern<ForallOp>::OpRewritePattern;
+
+ /// Utility function that checks if a candidate value satisfies any of the
+ /// conditions (see above doc comment) to make it viable for folding away.
+ static bool isCandidateValueToDelete(Value result, BlockArgument blockArg) {
+ if (result.use_empty()) {
+ return true;
+ }
+ Value::user_range users = blockArg.getUsers();
+ return llvm::all_of(users, [&](Operation *user) {
+ return !isa<SubsetInsertionOpInterface>(user);
+ });
+ }
+
+ LogicalResult matchAndRewrite(ForallOp forallOp,
+ PatternRewriter &rewriter) const final {
+ scf::InParallelOp terminatorOp = forallOp.getTerminator();
+ SmallVector<Operation *> yieldingOps = llvm::map_to_vector(
+ terminatorOp.getYieldingOps(), [](Operation &op) { return &op; });
+
+ // The following check should indeed be part of SCF::ForallOp::verify.
+ SmallVector<SubsetInsertionOpInterface> subsetInsertionOpInterfaceOps;
+ for (Operation *op : yieldingOps) {
+ if (auto subsetInsertionOpInterfaceOp =
+ dyn_cast<SubsetInsertionOpInterface>(op)) {
+ subsetInsertionOpInterfaceOps.push_back(subsetInsertionOpInterfaceOp);
+ continue;
+ }
+ return failure();
+ }
+
+ // Step 1: For a given i-th result of scf.forall, check the following :-
+ // a. If it has any use.
+ // b. If the corresponding iter argument is being modified within
+ // the loop.
+ //
+ // Based on the check we maintain the following :-
+ // a. `resultToDelete` - i-th result of scf.forall that'll be
+ // deleted.
+ // b. `resultToReplace` - i-th result of the old scf.forall
+ // whose uses will be replaced by the new scf.forall.
+ // c. `newOuts` - the shared_outs' operand of the new scf.forall
+ // corresponding to the i-th result with at least one use.
+ // d. `mapping` - mapping the old iter block argument of scf.forall
+ // with the corresponding shared_outs' operand. This will be
+ // used when creating a new scf.forall op.
+ SmallVector<OpResult> resultToDelete;
+ SmallVector<Value> resultToReplace;
+ SmallVector<Value> newOuts;
+ IRMapping mapping;
+ for (OpResult result : forallOp.getResults()) {
+ OpOperand *opOperand = forallOp.getTiedOpOperand(result);
+ BlockArgument blockArg = forallOp.getTiedBlockArgument(opOperand);
+ if (isCandidateValueToDelete(result, blockArg)) {
+ resultToDelete.push_back(result);
+ mapping.map(blockArg, opOperand->get());
+ } else {
+ resultToReplace.push_back(result);
+ newOuts.push_back(opOperand->get());
+ }
+ }
+
+ // Return early if all results of scf.forall have at least one use and are
+ // being modified within the loop.
+ if (resultToDelete.empty()) {
+ return failure();
+ }
+
+ // Step 2: For the i-th result, do the following :-
+ // a. Fetch the corresponding BlockArgument.
+ // b. Look for an op within scf.forall.in_parallel whose destination
+ // operand is the BlockArgument fetched in step a.
+ // c. Remove the operation fetched in b.
+ // d. For any use of the BlockArgument in the body of the scf.forall
+ // replace it with the corresponding Output value.
+ for (OpResult result : resultToDelete) {
+ OpOperand *opOperand = forallOp.getTiedOpOperand(result);
+ BlockArgument blockArg = forallOp.getTiedBlockArgument(opOperand);
+ Value::user_range users = blockArg.getUsers();
+ Operation *terminatorOperationToDelete = nullptr;
+ for (Operation *user : users) {
+ if (auto subsetInsertionOpInterfaceOp =
+ dyn_cast<SubsetInsertionOpInterface>(user)) {
+ if (subsetInsertionOpInterfaceOp.getDestinationOperand().get() ==
+ blockArg) {
+ terminatorOperationToDelete = subsetInsertionOpInterfaceOp;
+ break;
+ }
+ }
+ }
+ if (terminatorOperationToDelete)
+ rewriter.eraseOp(terminatorOperationToDelete);
+ }
+
+ // Step 3. Create a new scf.forall op with the new shared_outs' operands
+ // fetched earlier
+ auto newforallOp = rewriter.create<scf::ForallOp>(
+ forallOp.getLoc(), forallOp.getMixedLowerBound(),
+ forallOp.getMixedUpperBound(), forallOp.getMixedStep(), newOuts,
+ forallOp.getMapping());
+
+ // Step 4. Clone the region of the old scf.forall into the newly created
+ // scf.forall using the IRMapping formed in Step 1.
+ newforallOp.getBodyRegion().getBlocks().clear();
+ rewriter.cloneRegionBefore(forallOp.getRegion(), newforallOp.getRegion(),
+ newforallOp.getRegion().begin(), mapping);
+
+ // Step 5. Replace the uses of result of old scf.forall with that of the new
+ // scf.forall.
+ for (auto &&[oldResult, newResult] :
+ llvm::zip(resultToReplace, newforallOp->getResults())) {
+ rewriter.replaceAllUsesWith(oldResult, newResult);
+ }
+ // Step 6. Replace the uses of those values that either has no use or are
+ // not being modified within the loop with the corresponding
+ // OpOperand.
+ for (OpResult oldResult : resultToDelete) {
+ rewriter.replaceAllUsesWith(oldResult,
+ forallOp.getTiedOpOperand(oldResult)->get());
+ }
+ return success();
+ }
+};
+
struct ForallOpSingleOrZeroIterationDimsFolder
: public OpRewritePattern<ForallOp> {
using OpRewritePattern<ForallOp>::OpRewritePattern;
@@ -1667,7 +1865,7 @@ struct FoldTensorCastOfOutputIntoForallOp
void ForallOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
results.add<DimOfForallOp, FoldTensorCastOfOutputIntoForallOp,
- ForallOpControlOperandsFolder,
+ ForallOpControlOperandsFolder, ForallOpIterArgsFolder,
ForallOpSingleOrZeroIterationDimsFolder>(context);
}
diff --git a/mlir/test/Dialect/SCF/canonicalize.mlir b/mlir/test/Dialect/SCF/canonicalize.mlir
index b4c9ed4db94e0e..9b379ad15f1ecf 100644
--- a/mlir/test/Dialect/SCF/canonicalize.mlir
+++ b/mlir/test/Dialect/SCF/canonicalize.mlir
@@ -1735,6 +1735,86 @@ func.func @do_not_fold_tensor_cast_from_dynamic_to_static_type_into_forall(
// -----
+#map = affine_map<()[s0, s1] -> (s0 ceildiv s1)>
+#map1 = affine_map<(d0)[s0] -> (d0 * s0)>
+#map2 = affine_map<(d0)[s0, s1] -> (-(d0 * s1) + s0, s1)>
+module {
+ func.func @fold_iter_args_not_being_modified_within_scfforall(%arg0: index, %arg1: tensor<?xf32>, %arg2: tensor<?xf32>) -> (tensor<?xf32>, tensor<?xf32>) {
+ %c0 = arith.constant 0 : index
+ %cst = arith.constant 4.200000e+01 : f32
+ %0 = linalg.fill ins(%cst : f32) outs(%arg1 : tensor<?xf32>) -> tensor<?xf32>
+ %dim = tensor.dim %arg1, %c0 : tensor<?xf32>
+ %1 = affine.apply #map()[%dim, %arg0]
+ %2:2 = scf.forall (%arg3) in (%1) shared_outs(%arg4 = %arg1, %arg5 = %arg2) -> (tensor<?xf32>, tensor<?xf32>) {
+ %3 = affine.apply #map1(%arg3)[%arg0]
+ %4 = affine.min #map2(%arg3)[%dim, %arg0]
+ %extracted_slice0 = tensor.extract_slice %arg4[%3] [%4] [1] : tensor<?xf32> to tensor<?xf32>
+ %extracted_slice1 = tensor.extract_slice %arg5[%3] [%4] [1] : tensor<?xf32> to tensor<?xf32>
+ %5 = linalg.elemwise_unary ins(%extracted_slice0 : tensor<?xf32>) outs(%extracted_slice1 : tensor<?xf32>) -> tensor<?xf32>
+ scf.forall.in_parallel {
+ tensor.parallel_insert_slice %5 into %arg5[%3] [%4] [1] : tensor<?xf32> into tensor<?xf32>
+ }
+ }
+ return %2#0, %2#1 : tensor<?xf32>, tensor<?xf32>
+ }
+}
+// CHECK-LABEL: @fold_iter_args_not_being_modified_within_scfforall
+// CHECK-SAME: (%{{.*}}: index, %[[ARG1:.*]]: tensor<?xf32>, %[[ARG2:.*]]: tensor<?xf32>) -> (tensor<?xf32>, tensor<?xf32>) {
+// CHECK: %[[RESULT:.*]] = scf.forall
+// CHECK-SAME: shared_outs(%[[ITER_ARG_5:.*]] = %[[ARG2]]) -> (tensor<?xf32>) {
+// CHECK: %[[OPERAND0:.*]] = tensor.extract_slice %[[ARG1]]
+// CHECK: %[[OPERAND1:.*]] = tensor.extract_slice %[[ITER_ARG_5]]
+// CHECK: %[[ELEM:.*]] = linalg.elemwise_unary ins(%[[OPERAND0]] : tensor<?xf32>) outs(%[[OPERAND1]] : tensor<?xf32>) -> tensor<?xf32>
+// CHECK: scf.forall.in_parallel {
+// CHECK-NEXT: tensor.parallel_insert_slice %[[ELEM]] into %[[ITER_ARG_5]]
+// CHECK-NEXT: }
+// CHECK-NEXT: }
+// CHECK-NEXT: return %[[ARG1]], %[[RESULT]]
+
+// -----
+
+#map = affine_map<()[s0, s1] -> (s0 ceildiv s1)>
+#map1 = affine_map<(d0)[s0] -> (d0 * s0)>
+#map2 = affine_map<(d0)[s0, s1] -> (-(d0 * s1) + s0, s1)>
+module {
+ func.func @fold_iter_args_with_no_use_of_result_scfforall(%arg0: index, %arg1: tensor<?xf32>, %arg2: tensor<?xf32>, %arg3: tensor<?xf32>) -> tensor<?xf32> {
+ %cst = arith.constant 4.200000e+01 : f32
+ %c0 = arith.constant 0 : index
+ %0 = linalg.fill ins(%cst : f32) outs(%arg1 : tensor<?xf32>) -> tensor<?xf32>
+ %dim = tensor.dim %arg1, %c0 : tensor<?xf32>
+ %1 = affine.apply #map()[%dim, %arg0]
+ %2:3 = scf.forall (%arg4) in (%1) shared_outs(%arg5 = %arg1, %arg6 = %arg2, %arg7 = %arg3) -> (tensor<?xf32>, tensor<?xf32>, tensor<?xf32>) {
+ %3 = affine.apply #map1(%arg4)[%arg0]
+ %4 = affine.min #map2(%arg4)[%dim, %arg0]
+ %extracted_slice = tensor.extract_slice %arg5[%3] [%4] [1] : tensor<?xf32> to tensor<?xf32>
+ %extracted_slice_0 = tensor.extract_slice %arg6[%3] [%4] [1] : tensor<?xf32> to tensor<?xf32>
+ %extracted_slice_1 = tensor.extract_slice %arg7[%3] [%4] [1] : tensor<?xf32> to tensor<?xf32>
+ %extracted_slice_2 = tensor.extract_slice %0[%3] [%4] [1] : tensor<?xf32> to tensor<?xf32>
+ %5 = linalg.elemwise_unary ins(%extracted_slice : tensor<?xf32>) outs(%extracted_slice_1 : tensor<?xf32>) -> tensor<?xf32>
+ scf.forall.in_parallel {
+ tensor.parallel_insert_slice %5 into %arg6[%3] [%4] [1] : tensor<?xf32> into tensor<?xf32>
+ tensor.parallel_insert_slice %extracted_slice into %arg5[%3] [%4] [1] : tensor<?xf32> into tensor<?xf32>
+ tensor.parallel_insert_slice %extracted_slice_0 into %arg7[%3] [%4] [1] : tensor<?xf32> into tensor<?xf32>
+ }
+ }
+ return %2#1 : tensor<?xf32>
+ }
+}
+// CHECK-LABEL: @fold_iter_args_with_no_use_of_result_scfforall
+// CHECK-SAME: (%{{.*}}: index, %[[ARG1:.*]]: tensor<?xf32>, %[[ARG2:.*]]: tensor<?xf32>, %[[ARG3:.*]]: tensor<?xf32>) -> tensor<?xf32> {
+// CHECK: %[[RESULT:.*]] = scf.forall
+// CHECK-SAME: shared_outs(%[[ITER_ARG_6:.*]] = %[[ARG2]]) -> (tensor<?xf32>) {
+// CHECK: %[[OPERAND0:.*]] = tensor.extract_slice %[[ARG1]]
+// CHECK: %[[OPERAND1:.*]] = tensor.extract_slice %[[ARG3]]
+// CHECK: %[[ELEM:.*]] = linalg.elemwise_unary ins(%[[OPERAND0]] : tensor<?xf32>) outs(%[[OPERAND1]] : tensor<?xf32>) -> tensor<?xf32>
+// CHECK: scf.forall.in_parallel {
+// CHECK-NEXT: tensor.parallel_insert_slice %[[ELEM]] into %[[ITER_ARG_6]]
+// CHECK-NEXT: }
+// CHECK-NEXT: }
+// CHECK-NEXT: return %[[RESULT]]
+
+// -----
+
func.func @index_switch_fold() -> (f32, f32) {
%switch_cst = arith.constant 1: index
%0 = scf.index_switch %switch_cst -> f32
>From bb3973ad47421a87615315a215f5d1f792d9b7c5 Mon Sep 17 00:00:00 2001
From: Abhishek Varma <abhvarma at amd.com>
Date: Thu, 2 May 2024 08:30:52 +0000
Subject: [PATCH 2/2] Address review comments
---
mlir/include/mlir/Dialect/SCF/IR/SCF.h | 1 +
mlir/include/mlir/Dialect/SCF/IR/SCFOps.td | 6 ++
mlir/lib/Dialect/SCF/IR/SCF.cpp | 93 +++++++++++++---------
3 files changed, 64 insertions(+), 36 deletions(-)
diff --git a/mlir/include/mlir/Dialect/SCF/IR/SCF.h b/mlir/include/mlir/Dialect/SCF/IR/SCF.h
index 644118ca884c6b..8edc1a6cc04f26 100644
--- a/mlir/include/mlir/Dialect/SCF/IR/SCF.h
+++ b/mlir/include/mlir/Dialect/SCF/IR/SCF.h
@@ -24,6 +24,7 @@
#include "mlir/Interfaces/LoopLikeInterface.h"
#include "mlir/Interfaces/ParallelCombiningOpInterface.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
+#include "mlir/Interfaces/SubsetOpInterface.h"
#include "mlir/Interfaces/ViewLikeInterface.h"
namespace mlir {
diff --git a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
index b3d085bfff1af9..d0a780c1b0a0bd 100644
--- a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
+++ b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
@@ -21,6 +21,7 @@ include "mlir/Interfaces/DestinationStyleOpInterface.td"
include "mlir/Interfaces/InferTypeOpInterface.td"
include "mlir/Interfaces/ParallelCombiningOpInterface.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
+include "mlir/Interfaces/SubsetOpInterface.td"
include "mlir/Interfaces/ViewLikeInterface.td"
def SCF_Dialect : Dialect {
@@ -608,6 +609,11 @@ def ForallOp : SCF_Op<"forall", [
// Declare the shared_outs as inits/outs to DestinationStyleOpInterface.
MutableOperandRange getDpsInitsMutable() { return getOutputsMutable(); }
+
+ /// Returns the only user of the block argument within forall.in_parallel
+ /// which is a tensor.parallel_insert_slice. Returns failure if it finds
+ /// more than one such users.
+ FailureOr<::mlir::SubsetInsertionOpInterface> getStoreOpUser(BlockArgument bbArg);
}];
}
diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
index 355cfc8b3ee626..aab2d14dd2fe88 100644
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -1416,6 +1416,31 @@ InParallelOp ForallOp::getTerminator() {
return cast<InParallelOp>(getBody()->getTerminator());
}
+FailureOr<SubsetInsertionOpInterface>
+ForallOp::getStoreOpUser(BlockArgument bbArg) {
+ Value::user_range users = bbArg.getUsers();
+ bool foundUser = false;
+ SubsetInsertionOpInterface storeOp = nullptr;
+ for (Operation *userOp : users) {
+ if (auto parallelInsertSliceOp =
+ dyn_cast<tensor::ParallelInsertSliceOp>(userOp);
+ parallelInsertSliceOp && isa<InParallelOp>(userOp->getParentOp())) {
+ // Return failure in case we find more than one user of the block argument
+ // within scf.forall.in_parallel.
+ if (foundUser) {
+ return failure();
+ }
+ storeOp = cast<SubsetInsertionOpInterface>(userOp);
+ if (storeOp.getDestinationOperand().get() == bbArg) {
+ foundUser = true;
+ }
+ }
+ }
+ if (foundUser)
+ return storeOp;
+ return failure();
+}
+
std::optional<Value> ForallOp::getSingleInductionVar() {
if (getRank() != 1)
return std::nullopt;
@@ -1555,6 +1580,8 @@ class ForallOpControlOperandsFolder : public OpRewritePattern<ForallOp> {
/// 3. The order of the <STORE OP WITH DESTINATION *> can be arbitrary
/// within scf.forall.in_parallel - the code below takes care of this
/// by traversing the uses of the corresponding iter arg.
+/// 4. TODO(avarma): Generalize it for other store ops. Currently it
+/// handles tensor.parallel_insert_slice ops only.
///
/// Example of second case :-
/// INPUT:
@@ -1585,18 +1612,6 @@ class ForallOpControlOperandsFolder : public OpRewritePattern<ForallOp> {
struct ForallOpIterArgsFolder : public OpRewritePattern<ForallOp> {
using OpRewritePattern<ForallOp>::OpRewritePattern;
- /// Utility function that checks if a candidate value satisfies any of the
- /// conditions (see above doc comment) to make it viable for folding away.
- static bool isCandidateValueToDelete(Value result, BlockArgument blockArg) {
- if (result.use_empty()) {
- return true;
- }
- Value::user_range users = blockArg.getUsers();
- return llvm::all_of(users, [&](Operation *user) {
- return !isa<SubsetInsertionOpInterface>(user);
- });
- }
-
LogicalResult matchAndRewrite(ForallOp forallOp,
PatternRewriter &rewriter) const final {
scf::InParallelOp terminatorOp = forallOp.getTerminator();
@@ -1617,7 +1632,7 @@ struct ForallOpIterArgsFolder : public OpRewritePattern<ForallOp> {
// Step 1: For a given i-th result of scf.forall, check the following :-
// a. If it has any use.
// b. If the corresponding iter argument is being modified within
- // the loop.
+ // the loop, i.e. fetch a unique store op.
//
// Based on the check we maintain the following :-
// a. `resultToDelete` - i-th result of scf.forall that'll be
@@ -1636,7 +1651,7 @@ struct ForallOpIterArgsFolder : public OpRewritePattern<ForallOp> {
for (OpResult result : forallOp.getResults()) {
OpOperand *opOperand = forallOp.getTiedOpOperand(result);
BlockArgument blockArg = forallOp.getTiedBlockArgument(opOperand);
- if (isCandidateValueToDelete(result, blockArg)) {
+ if (result.use_empty() || failed(forallOp.getStoreOpUser(blockArg))) {
resultToDelete.push_back(result);
mapping.map(blockArg, opOperand->get());
} else {
@@ -1653,28 +1668,20 @@ struct ForallOpIterArgsFolder : public OpRewritePattern<ForallOp> {
// Step 2: For the i-th result, do the following :-
// a. Fetch the corresponding BlockArgument.
- // b. Look for an op within scf.forall.in_parallel whose destination
- // operand is the BlockArgument fetched in step a.
+ // b. Look for a unique store op (currently
+ // tensor.parallel_insert_slice) with the BlockArgument as its
+ // Destination operand.
// c. Remove the operation fetched in b.
- // d. For any use of the BlockArgument in the body of the scf.forall
- // replace it with the corresponding Output value.
for (OpResult result : resultToDelete) {
OpOperand *opOperand = forallOp.getTiedOpOperand(result);
BlockArgument blockArg = forallOp.getTiedBlockArgument(opOperand);
- Value::user_range users = blockArg.getUsers();
- Operation *terminatorOperationToDelete = nullptr;
- for (Operation *user : users) {
- if (auto subsetInsertionOpInterfaceOp =
- dyn_cast<SubsetInsertionOpInterface>(user)) {
- if (subsetInsertionOpInterfaceOp.getDestinationOperand().get() ==
- blockArg) {
- terminatorOperationToDelete = subsetInsertionOpInterfaceOp;
- break;
- }
- }
+ FailureOr<SubsetInsertionOpInterface> storeOp =
+ forallOp.getStoreOpUser(blockArg);
+ rewriter.replaceAllUsesWith(blockArg, opOperand->get());
+ if (failed(storeOp)) {
+ continue;
}
- if (terminatorOperationToDelete)
- rewriter.eraseOp(terminatorOperationToDelete);
+ rewriter.eraseOp(storeOp.value());
}
// Step 3. Create a new scf.forall op with the new shared_outs' operands
@@ -1684,11 +1691,25 @@ struct ForallOpIterArgsFolder : public OpRewritePattern<ForallOp> {
forallOp.getMixedUpperBound(), forallOp.getMixedStep(), newOuts,
forallOp.getMapping());
- // Step 4. Clone the region of the old scf.forall into the newly created
- // scf.forall using the IRMapping formed in Step 1.
- newforallOp.getBodyRegion().getBlocks().clear();
- rewriter.cloneRegionBefore(forallOp.getRegion(), newforallOp.getRegion(),
- newforallOp.getRegion().begin(), mapping);
+ // Step 4. Merge the block of the old scf.forall into the newly created
+ // scf.forall using the new set of arguments.
+ Block *loopBody = forallOp.getBody();
+ Block *newLoopBody = newforallOp.getBody();
+ ArrayRef<BlockArgument> newBbArgs = newLoopBody->getArguments();
+ SmallVector<Value> newBlockArgs =
+ llvm::map_to_vector(newBbArgs.take_front(forallOp.getRank()),
+ [](BlockArgument b) -> Value { return b; });
+ Block::BlockArgListType newSharedOutsArgs = newforallOp.getRegionOutArgs();
+ unsigned index = 0;
+ for (BlockArgument bbArg : forallOp.getRegionOutArgs()) {
+ if (mapping.contains(bbArg)) {
+ newBlockArgs.push_back(mapping.lookup(bbArg));
+ } else {
+ newBlockArgs.push_back(newSharedOutsArgs[index++]);
+ }
+ }
+ rewriter.eraseOp(newforallOp.getTerminator());
+ rewriter.mergeBlocks(loopBody, newLoopBody, newBlockArgs);
// Step 5. Replace the uses of result of old scf.forall with that of the new
// scf.forall.
More information about the Mlir-commits
mailing list