[Mlir-commits] [mlir] [MLIR][SCF] Add canonicalization pattern to fold away iter args of scf.forall (PR #90189)

llvmlistbot at llvm.org
Wed May 1 16:51:12 PDT 2024

@@ -1509,6 +1510,203 @@ class ForallOpControlOperandsFolder : public OpRewritePattern<ForallOp> {
+/// The following canonicalization pattern folds the iter arguments of an
+/// scf.forall op if:
+/// 1. The corresponding result has zero uses.
+/// 2. The iter argument is NOT modified within the loop body.
+/// Example of the first case:
+///  INPUT:
+///   %res:3 = scf.forall ... shared_outs(%arg0 = %a, %arg1 = %b, %arg2 = %c)
+///            {
+///                ...
+///                <SOME USE OF %arg0>
+///                <SOME USE OF %arg1>
+///                <SOME USE OF %arg2>
+///                ...
+///                scf.forall.in_parallel {
+///                    <STORE OP WITH DESTINATION %arg1>
+///                    <STORE OP WITH DESTINATION %arg0>
+///                    <STORE OP WITH DESTINATION %arg2>
+///                }
+///             }
+///   return %res#1
+///  OUTPUT:
+///   %res = scf.forall ... shared_outs(%new_arg0 = %b)
+///            {
+///                ...
+///                <SOME USE OF %a>
+///                <SOME USE OF %new_arg0>
+///                <SOME USE OF %c>
+///                ...
+///                scf.forall.in_parallel {
+///                    <STORE OP WITH DESTINATION %new_arg0>
+///                }
+///             }
+///   return %res
+/// NOTE: 1. All uses of the folded shared_outs (iter arguments) within the
+///          scf.forall are replaced by their corresponding init operands.
+///       2. The canonicalization assumes that there are no <STORE OP WITH
+///          DESTINATION *> ops within the body of the scf.forall except within
+///          the scf.forall.in_parallel terminator.
+///       3. The order of the <STORE OP WITH DESTINATION *> ops within
+///          scf.forall.in_parallel can be arbitrary; the code below handles
+///          this by traversing the uses of the corresponding iter arg.
+/// Example of the second case:
+///  INPUT:
+///   %res:2 = scf.forall ... shared_outs(%arg0 = %a, %arg1 = %b)
+///            {
+///                ...
+///                <SOME USE OF %arg0>
+///                <SOME USE OF %arg1>
+///                ...
+///                scf.forall.in_parallel {
+///                    <STORE OP WITH DESTINATION %arg1>
+///                }
+///             }
+///   return %res#0, %res#1
+///  OUTPUT:
+///   %res = scf.forall ... shared_outs(%new_arg0 = %b)
+///            {
+///                ...
+///                <SOME USE OF %a>
+///                <SOME USE OF %new_arg0>
+///                ...
+///                scf.forall.in_parallel {
+///                    <STORE OP WITH DESTINATION %new_arg0>
+///                }
+///             }
+///   return %a, %res
+struct ForallOpIterArgsFolder : public OpRewritePattern<ForallOp> {
+  using OpRewritePattern<ForallOp>::OpRewritePattern;
+  /// Utility function that checks if a candidate value satisfies any of the
+  /// conditions (see above doc comment) to make it viable for folding away.
+  static bool isCandidateValueToDelete(Value result, BlockArgument blockArg) {
+    if (result.use_empty())
+      return true;
+    Value::user_range users = blockArg.getUsers();
+    return llvm::all_of(users, [&](Operation *user) {
+      return !isa<SubsetInsertionOpInterface>(user);
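The two folding conditions and the result remapping shown in the two examples can be sketched outside MLIR. This is a toy model only, not the PR's code. `SharedOut`, `ResultFate`, `isFoldable`, and `planFold` are hypothetical names, not MLIR APIs. The model also approximates "modified" as "stored into by the terminator", which relies on the doc comment's assumption that no other writes exist.

```cpp
#include <cassert>
#include <cstddef>
#include <optional>
#include <string>
#include <vector>

// Hypothetical, MLIR-free model of one shared_out of an scf.forall.
struct SharedOut {
  std::string init;         // init operand, e.g. "%a"
  bool resultHasUses;       // does the matching %res#i have any uses?
  bool storedInTerminator;  // does scf.forall.in_parallel store into it?
                            // (assumption: the only way it can be modified)
};

// Condition 1 (result unused) or condition 2 (iter arg never written).
bool isFoldable(const SharedOut &out) {
  return !out.resultHasUses || !out.storedInTerminator;
}

// Fate of one old result after the fold.
struct ResultFate {
  std::optional<std::size_t> newIndex;     // set if the shared_out is kept
  std::optional<std::string> replacement;  // set if folded but still used
};

// Decide, for every old shared_out, whether it survives (and at which new
// result index) or is folded away. In both folded cases, in-body uses of the
// iter arg would be replaced by `init`; this model only tracks the results.
std::vector<ResultFate> planFold(const std::vector<SharedOut> &outs) {
  std::vector<ResultFate> plan;
  plan.reserve(outs.size());
  std::size_t nextIndex = 0;
  for (const SharedOut &out : outs) {
    ResultFate fate;
    if (!isFoldable(out)) {
      fate.newIndex = nextIndex++;
    } else if (out.resultHasUses) {
      // Second case: the iter arg is never written, so the result is `init`.
      fate.replacement = out.init;
    }
    // Otherwise first case: the result is dead, and so is any store into it.
    plan.push_back(fate);
  }
  assert(plan.size() == outs.size() && "one fate per old result");
  return plan;
}
```

Consider the first example, where `%res#0` and `%res#2` are unused and all three iter args are stored to. The model keeps only the `%b` shared_out, at new index 0. For the second example it keeps `%b` at index 0 and replaces `%res#0` with `%a`.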
MaheshRavishankar wrote:

I don't know if we can get to a general solution right away. It's hard to handle all possible "insertion-like" ops in full generality. I would go the other way. To start with, I would check that the only uses of the iter_args are in `tensor.parallel_insert_slice` ops within the `scf.forall.in_parallel`, and only then drop the result (and the corresponding iter_arg and `tensor.parallel_insert_slice`). We don't have any semantics for any other case right now, and it's not worth generalizing in a vacuum. Let's start small but well-defined and go from there.
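Below is one possible reading of the stricter check suggested above, again as a hypothetical toy model rather than MLIR code. Under this reading, a result is dropped only if it has no uses and every use of its iter arg is a `tensor.parallel_insert_slice` (the in-parallel insert op the comment refers to) inside the `scf.forall.in_parallel` terminator. `UserKind`, `IterArgInfo`, and `canDropConservatively` are illustrative names only.

```cpp
#include <algorithm>
#include <cassert>
#include <vector>

// Hypothetical classification of a user of an scf.forall iter arg.
enum class UserKind {
  ParallelInsertSliceInTerminator,  // parallel insert inside in_parallel
  Other                             // anything else: reads, other ops, ...
};

// Hypothetical summary of one shared_out.
struct IterArgInfo {
  bool resultHasUses;
  std::vector<UserKind> users;  // users of the iter arg (block argument)
};

// Drop the result, its iter arg, and its insert ops only when the result is
// dead and nothing other than in_parallel inserts touches the iter arg.
bool canDropConservatively(const IterArgInfo &info) {
  if (info.resultHasUses)
    return false;
  return std::all_of(info.users.begin(), info.users.end(), [](UserKind kind) {
    return kind == UserKind::ParallelInsertSliceInTerminator;
  });
}
```

`isCandidateValueToDelete` in the diff accepts any iter arg whose result is unused. This model is stricter: it also rejects iter args that have any other user, such as a read in the loop body.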

