[Mlir-commits] [mlir] [MLIR][SCF] Add canonicalization pattern to fold away iter args of scf.forall (PR #90189)
llvmlistbot at llvm.org
Thu May 2 15:00:42 PDT 2024
================
@@ -1509,6 +1535,199 @@ class ForallOpControlOperandsFolder : public OpRewritePattern<ForallOp> {
}
};
+/// The following canonicalization pattern folds the iter arguments of an
+/// scf.forall op if:
+/// 1. The corresponding result has zero uses.
+/// 2. The iter argument is NOT being modified within the loop body.
+///
+/// Example of the first case:
+/// INPUT:
+/// %res:3 = scf.forall ... shared_outs(%arg0 = %a, %arg1 = %b, %arg2 = %c)
+/// {
+/// ...
+/// <SOME USE OF %arg0>
+/// <SOME USE OF %arg1>
+/// <SOME USE OF %arg2>
+/// ...
+/// scf.forall.in_parallel {
+/// <STORE OP WITH DESTINATION %arg1>
+/// <STORE OP WITH DESTINATION %arg0>
+/// <STORE OP WITH DESTINATION %arg2>
+/// }
+/// }
+/// return %res#1
+///
+/// OUTPUT:
+/// %res = scf.forall ... shared_outs(%new_arg0 = %b)
+/// {
+/// ...
+/// <SOME USE OF %a>
+/// <SOME USE OF %new_arg0>
+/// <SOME USE OF %c>
+/// ...
+/// scf.forall.in_parallel {
+/// <STORE OP WITH DESTINATION %new_arg0>
+/// }
+/// }
+/// return %res
+///
+/// NOTE: 1. All uses of the folded shared_outs (iter argument) within the
+/// scf.forall are replaced by their corresponding operands.
+/// 2. The canonicalization assumes that there are no <STORE OP WITH
+/// DESTINATION *> ops within the body of the scf.forall except within
+/// scf.forall.in_parallel terminator.
+/// 3. The order of the <STORE OP WITH DESTINATION *> can be arbitrary
+/// within scf.forall.in_parallel - the code below takes care of this
+/// by traversing the uses of the corresponding iter arg.
+/// 4. TODO(avarma): Generalize it for other store ops. Currently it
+/// handles tensor.parallel_insert_slice ops only.
+///
+/// Example of the second case:
+/// INPUT:
+/// %res:2 = scf.forall ... shared_outs(%arg0 = %a, %arg1 = %b)
+/// {
+/// ...
+/// <SOME USE OF %arg0>
+/// <SOME USE OF %arg1>
+/// ...
+/// scf.forall.in_parallel {
+/// <STORE OP WITH DESTINATION %arg1>
+/// }
+/// }
+/// return %res#0, %res#1
+///
+/// OUTPUT:
+/// %res = scf.forall ... shared_outs(%new_arg0 = %b)
+/// {
+/// ...
+/// <SOME USE OF %a>
+/// <SOME USE OF %new_arg0>
+/// ...
+/// scf.forall.in_parallel {
+/// <STORE OP WITH DESTINATION %new_arg0>
+/// }
+/// }
+/// return %a, %res
+struct ForallOpIterArgsFolder : public OpRewritePattern<ForallOp> {
+ using OpRewritePattern<ForallOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(ForallOp forallOp,
+ PatternRewriter &rewriter) const final {
+ scf::InParallelOp terminatorOp = forallOp.getTerminator();
+ SmallVector<Operation *> yieldingOps = llvm::map_to_vector(
+ terminatorOp.getYieldingOps(), [](Operation &op) { return &op; });
+
+ // The following check should ideally be part of scf::ForallOp::verify.
+ SmallVector<SubsetInsertionOpInterface> subsetInsertionOpInterfaceOps;
+ for (Operation *op : yieldingOps) {
+ if (auto subsetInsertionOpInterfaceOp =
+ dyn_cast<SubsetInsertionOpInterface>(op)) {
+ subsetInsertionOpInterfaceOps.push_back(subsetInsertionOpInterfaceOp);
+ continue;
+ }
+ return failure();
+ }
+
+ // Step 1: For a given i-th result of scf.forall, check the following:
+ // a. If it has any use.
+ // b. If the corresponding iter argument is being modified within
+ // the loop, i.e. fetch a unique store op.
+ //
+ // Based on these checks, we maintain the following:
+ // a. `resultToDelete` - i-th result of scf.forall that'll be
+ // deleted.
+ // b. `resultToReplace` - i-th result of the old scf.forall
+ // whose uses will be replaced by the new scf.forall.
+ // c. `newOuts` - the shared_outs' operand of the new scf.forall
+ // corresponding to the i-th result with at least one use.
+ // d. `mapping` - mapping the old iter block argument of scf.forall
+ // with the corresponding shared_outs' operand. This will be
+ // used when creating a new scf.forall op.
+ SmallVector<OpResult> resultToDelete;
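The per-result decision in Step 1 can be modeled without MLIR as a minimal sketch. `ResultInfo`, `FoldPlan`, `planIterArgFolding`, and the extra `replaceWithInit` list are hypothetical names chosen for illustration; the two booleans stand in for the use check and the store-op lookup, and the real pattern works on `OpResult`s and block arguments rather than indices. Splitting the second case into its own list is an assumption about how the bookkeeping could be organized, not necessarily how the PR does it.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical, MLIR-free summary of the i-th scf.forall result and its
// shared_outs iter argument. Field names are illustrative only.
struct ResultInfo {
  bool resultHasUses;  // Check (a): does the i-th result have any use?
  bool iterArgStored;  // Check (b): does a store op in scf.forall.in_parallel
                       // write to the i-th iter argument?
};

// Indices into the old scf.forall results, grouped by what happens to them.
struct FoldPlan {
  // Results dropped from the new scf.forall. Inside the body, uses of their
  // iter arguments are replaced by the matching shared_outs operands.
  std::vector<std::size_t> resultToDelete;
  // Subset of resultToDelete whose results still have uses (second case):
  // those uses are replaced by the shared_outs operand itself.
  std::vector<std::size_t> replaceWithInit;
  // Results kept by the new scf.forall; their shared_outs operands form
  // newOuts, and their uses are rewired to the new op's results.
  std::vector<std::size_t> resultToReplace;
};

FoldPlan planIterArgFolding(const std::vector<ResultInfo> &results) {
  FoldPlan plan;
  for (std::size_t i = 0; i < results.size(); ++i) {
    const ResultInfo &r = results[i];
    if (r.resultHasUses && r.iterArgStored) {
      // Neither case applies: keep this shared_out.
      plan.resultToReplace.push_back(i);
      continue;
    }
    // First case (no uses) or second case (never stored to): fold it away.
    plan.resultToDelete.push_back(i);
    if (r.resultHasUses)
      plan.replaceWithInit.push_back(i);
  }
  // Every result is either deleted or kept, never both.
  assert(plan.resultToDelete.size() + plan.resultToReplace.size() ==
         results.size());
  return plan;
}
```

For the first example (results 0 and 2 unused, all three iter arguments stored to), this yields `resultToDelete = {0, 2}` and `resultToReplace = {1}`; for the second example it yields `resultToDelete = {0}`, `replaceWithInit = {0}`, and `resultToReplace = {1}`. When `resultToDelete` is empty there is nothing to fold, so a rewrite pattern should return `failure()` rather than recreate an identical op.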
----------------
MaheshRavishankar wrote:
Make this a `SetVector<OpResult>`; that way, later on, when you want to get the basic block arguments for the `mergeBlocks` call, you can use it to check whether a result has been deleted.
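The reviewer's suggestion can be illustrated with a minimal sketch that depends on neither MLIR nor LLVM. `OrderedSet` is a hypothetical stand-in for `llvm::SetVector` (which likewise pairs a vector with a set), and `buildIterArgReplacements` is a hypothetical helper. The real code would pass MLIR `Value`s to `mergeBlocks`, and the scf.forall body block also begins with the induction variables, which are omitted here.

```cpp
#include <cassert>
#include <cstddef>
#include <string>
#include <unordered_set>
#include <vector>

// Hypothetical stand-in for llvm::SetVector: keeps insertion order (like a
// vector) and answers membership queries in O(1) average time (like a set).
template <typename T>
class OrderedSet {
public:
  // Returns true if `v` was newly inserted.
  bool insert(const T &v) {
    if (!set_.insert(v).second)
      return false;
    vec_.push_back(v);
    return true;
  }
  bool contains(const T &v) const { return set_.count(v) != 0; }
  const std::vector<T> &items() const { return vec_; }

private:
  std::vector<T> vec_;
  std::unordered_set<T> set_;
};

// For each old iter block argument i, returns the value that should replace
// it when the old body is merged into the new scf.forall: the original
// shared_outs operand if result i was deleted, otherwise the next iter block
// argument of the new loop. SSA values are modeled as strings.
std::vector<std::string> buildIterArgReplacements(
    const std::vector<std::string> &initOperands,
    const std::vector<std::string> &newIterArgs,
    const OrderedSet<std::size_t> &deletedResults) {
  std::vector<std::string> replacements;
  std::size_t next = 0;
  for (std::size_t i = 0; i < initOperands.size(); ++i) {
    if (deletedResults.contains(i)) {
      replacements.push_back(initOperands[i]);
    } else {
      assert(next < newIterArgs.size() && "more kept results than new args");
      replacements.push_back(newIterArgs[next++]);
    }
  }
  // Every new iter block argument should have been consumed exactly once.
  assert(next == newIterArgs.size());
  return replacements;
}
```

With the first example (results 0 and 2 deleted), the old block arguments `%arg0, %arg1, %arg2` map to `%a, %new_arg0, %c`, matching the OUTPUT body. The set part gives a cheap membership test per block argument, while the vector part keeps any iteration over the deleted results deterministic.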
https://github.com/llvm/llvm-project/pull/90189