[Mlir-commits] [mlir] [MLIR][SCF] Add canonicalization pattern to fold away iter args of scf.forall (PR #90189)
Matthias Springer
llvmlistbot at llvm.org
Mon May 6 04:32:51 PDT 2024
================
@@ -1509,6 +1522,177 @@ class ForallOpControlOperandsFolder : public OpRewritePattern<ForallOp> {
}
};
+/// The following canonicalization pattern folds the iter arguments of an
+/// scf.forall op if either:
+/// 1. The corresponding result has zero uses, or
+/// 2. The iter argument is NOT being modified within the loop body.
+///
+/// Example of first case :-
+/// INPUT:
+/// %res:3 = scf.forall ... shared_outs(%arg0 = %a, %arg1 = %b, %arg2 = %c)
+/// {
+/// ...
+/// <SOME USE OF %arg0>
+/// <SOME USE OF %arg1>
+/// <SOME USE OF %arg2>
+/// ...
+/// scf.forall.in_parallel {
+/// <STORE OP WITH DESTINATION %arg1>
+/// <STORE OP WITH DESTINATION %arg0>
+/// <STORE OP WITH DESTINATION %arg2>
+/// }
+/// }
+/// return %res#1
+///
+/// OUTPUT:
+/// %res = scf.forall ... shared_outs(%new_arg0 = %b)
+/// {
+/// ...
+/// <SOME USE OF %a>
+/// <SOME USE OF %new_arg0>
+/// <SOME USE OF %c>
+/// ...
+/// scf.forall.in_parallel {
+/// <STORE OP WITH DESTINATION %new_arg0>
+/// }
+/// }
+/// return %res
+///
+/// NOTE: 1. All uses of the folded shared_outs (iter arguments) within the
+/// scf.forall are replaced by their corresponding init operands.
+/// 2. Even if there are <STORE OP WITH DESTINATION *> ops within the body
+/// of the scf.forall besides within scf.forall.in_parallel terminator,
+/// this canonicalization remains valid. For more details, please refer
+/// to :
+/// https://github.com/llvm/llvm-project/pull/90189#discussion_r1589011124
+/// 3. TODO(avarma): Generalize it for other store ops. Currently it
+/// handles tensor.parallel_insert_slice ops only.
+///
+/// Example of second case :-
+/// INPUT:
+/// %res:2 = scf.forall ... shared_outs(%arg0 = %a, %arg1 = %b)
+/// {
+/// ...
+/// <SOME USE OF %arg0>
+/// <SOME USE OF %arg1>
+/// ...
+/// scf.forall.in_parallel {
+/// <STORE OP WITH DESTINATION %arg1>
+/// }
+/// }
+/// return %res#0, %res#1
+///
+/// OUTPUT:
+/// %res = scf.forall ... shared_outs(%new_arg0 = %b)
+/// {
+/// ...
+/// <SOME USE OF %a>
+/// <SOME USE OF %new_arg0>
+/// ...
+/// scf.forall.in_parallel {
+/// <STORE OP WITH DESTINATION %new_arg0>
+/// }
+/// }
+/// return %a, %res
+struct ForallOpIterArgsFolder : public OpRewritePattern<ForallOp> {
+ using OpRewritePattern<ForallOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(ForallOp forallOp,
+ PatternRewriter &rewriter) const final {
+ // Step 1: For a given i-th result of scf.forall, check the following :-
+ // a. If it has any use.
+ // b. If the corresponding iter argument is being modified within
+ // the loop, i.e. has at least one store op with the iter arg as
+ // its destination operand. For this we use
+ // ForallOp::getStoreOpUser(iter_arg).
+ //
+ // Based on the check we maintain the following :-
+ // a. `resultToDelete` - i-th result of scf.forall that'll be
+ // deleted.
+ // b. `resultToReplace` - i-th result of the old scf.forall
+ // whose uses will be replaced by the new scf.forall.
+ // c. `newOuts` - the shared_outs' operand of the new scf.forall
+ // corresponding to the i-th result with at least one use.
+ SetVector<OpResult> resultToDelete;
+ SmallVector<Value> resultToReplace;
+ SmallVector<Value> newOuts;
+ for (OpResult result : forallOp.getResults()) {
+ OpOperand *opOperand = forallOp.getTiedOpOperand(result);
+ BlockArgument blockArg = forallOp.getTiedBlockArgument(opOperand);
+ if (result.use_empty() || forallOp.getStoreOpUser(blockArg).empty()) {
+ resultToDelete.insert(result);
+ } else {
+ resultToReplace.push_back(result);
+ newOuts.push_back(opOperand->get());
+ }
+ }
+
+ // Return early if all results of scf.forall has at least one use and being
+ // modified within the loop.
+ if (resultToDelete.empty()) {
+ return failure();
+ }
+
+ // Step 2: For the the i-th result, do the following :-
+ // a. Fetch the corresponding BlockArgument.
+ // b. Look for store ops (currently tensor.parallel_insert_slice)
+ // with the BlockArgument as its destination operand.
+ // c. Remove the operations fetched in b.
+ for (OpResult result : resultToDelete) {
+ OpOperand *opOperand = forallOp.getTiedOpOperand(result);
+ BlockArgument blockArg = forallOp.getTiedBlockArgument(opOperand);
+ SmallVector<Operation *> storeOps = forallOp.getStoreOpUser(blockArg);
+ for (Operation *storeOp : storeOps) {
+ rewriter.eraseOp(storeOp);
+ }
+ }
+
+ // Step 3. Create a new scf.forall op with the new shared_outs' operands
+ // fetched earlier
+ auto newforallOp = rewriter.create<scf::ForallOp>(
+ forallOp.getLoc(), forallOp.getMixedLowerBound(),
+ forallOp.getMixedUpperBound(), forallOp.getMixedStep(), newOuts,
+ forallOp.getMapping());
+
+ // Step 4. Merge the block of the old scf.forall into the newly created
+ // scf.forall using the new set of arguments.
+ Block *loopBody = forallOp.getBody();
+ Block *newLoopBody = newforallOp.getBody();
+ ArrayRef<BlockArgument> newBbArgs = newLoopBody->getArguments();
+ SmallVector<Value> newBlockArgs =
+ llvm::map_to_vector(newBbArgs.take_front(forallOp.getRank()),
+ [](BlockArgument b) -> Value { return b; });
+ Block::BlockArgListType newSharedOutsArgs = newforallOp.getRegionOutArgs();
+ unsigned index = 0;
+ for (OpResult result : forallOp.getResults()) {
+ if (resultToDelete.count(result)) {
+ newBlockArgs.push_back(forallOp.getTiedOpOperand(result)->get());
+ } else {
+ newBlockArgs.push_back(newSharedOutsArgs[index++]);
+ }
+ }
+ rewriter.eraseOp(newforallOp.getTerminator());
+ rewriter.mergeBlocks(loopBody, newLoopBody, newBlockArgs);
+
+ // Step 5. Replace the uses of result of old scf.forall with that of the new
+ // scf.forall.
+ for (auto &&[oldResult, newResult] :
+ llvm::zip(resultToReplace, newforallOp->getResults())) {
+ rewriter.replaceAllUsesWith(oldResult, newResult);
+ }
+
+ // Step 6. Replace the uses of those values that either has no use or are
+ // not being modified within the loop with the corresponding
+ // OpOperand.
+ for (OpResult oldResult : resultToDelete) {
----------------
matthias-springer wrote:
nit: trivial braces not needed
https://github.com/llvm/llvm-project/pull/90189
More information about the Mlir-commits
mailing list