[Mlir-commits] [mlir] [MLIR][SCF] Add canonicalization pattern to fold away iter args of scf.forall (PR #90189)

Matthias Springer llvmlistbot at llvm.org
Fri May 3 02:57:06 PDT 2024


================
@@ -1509,6 +1510,203 @@ class ForallOpControlOperandsFolder : public OpRewritePattern<ForallOp> {
   }
 };
 
+/// The following canonicalization pattern folds away an iter argument of an
+/// scf.forall op if either:
+/// 1. The corresponding result has zero uses, or
+/// 2. The iter argument is NOT modified within the loop body.
+///
+/// Example of the first case:
+///  INPUT:
+///   %res:3 = scf.forall ... shared_outs(%arg0 = %a, %arg1 = %b, %arg2 = %c)
+///            {
+///                ...
+///                <SOME USE OF %arg0>
+///                <SOME USE OF %arg1>
+///                <SOME USE OF %arg2>
+///                ...
+///                scf.forall.in_parallel {
+///                    <STORE OP WITH DESTINATION %arg1>
+///                    <STORE OP WITH DESTINATION %arg0>
+///                    <STORE OP WITH DESTINATION %arg2>
+///                }
+///             }
+///   return %res#1
+///
+///  OUTPUT:
+///   %res = scf.forall ... shared_outs(%new_arg0 = %b)
+///            {
+///                ...
+///                <SOME USE OF %a>
+///                <SOME USE OF %new_arg0>
+///                <SOME USE OF %c>
+///                ...
+///                scf.forall.in_parallel {
+///                    <STORE OP WITH DESTINATION %new_arg0>
+///                }
+///             }
+///   return %res
+///
+/// NOTE: 1. All uses of the folded shared_outs (iter arguments) within the
+///          scf.forall are replaced by their corresponding init operands.
+///       2. The canonicalization assumes that there are no <STORE OP WITH
+///          DESTINATION *> ops within the body of the scf.forall except within
+///          the scf.forall.in_parallel terminator.
+///       3. The order of the <STORE OP WITH DESTINATION *> ops can be
+///          arbitrary within scf.forall.in_parallel; the code below takes care
+///          of this by traversing the uses of the corresponding iter arg.
+///
+/// Example of the second case:
+///  INPUT:
+///   %res:2 = scf.forall ... shared_outs(%arg0 = %a, %arg1 = %b)
+///            {
+///                ...
+///                <SOME USE OF %arg0>
+///                <SOME USE OF %arg1>
+///                ...
+///                scf.forall.in_parallel {
+///                    <STORE OP WITH DESTINATION %arg1>
+///                }
+///             }
+///   return %res#0, %res#1
+///
+///  OUTPUT:
+///   %res = scf.forall ... shared_outs(%new_arg0 = %b)
+///            {
+///                ...
+///                <SOME USE OF %a>
+///                <SOME USE OF %new_arg0>
+///                ...
+///                scf.forall.in_parallel {
+///                    <STORE OP WITH DESTINATION %new_arg0>
+///                }
+///             }
+///   return %a, %res
+struct ForallOpIterArgsFolder : public OpRewritePattern<ForallOp> {
+  using OpRewritePattern<ForallOp>::OpRewritePattern;
+
+  /// Utility function that checks if a candidate value satisfies any of the
+  /// conditions (see above doc comment) to make it viable for folding away.
+  static bool isCandidateValueToDelete(Value result, BlockArgument blockArg) {
+    if (result.use_empty()) {
+      return true;
+    }
+    Value::user_range users = blockArg.getUsers();
+    return llvm::all_of(users, [&](Operation *user) {
+      return !isa<SubsetInsertionOpInterface>(user);
----------------
matthias-springer wrote:

> How should we deal with unregistered ops that might appear anywhere in the loop body and insert a value into the bbArg?

If you just look for the iter_arg being used as a destination in the terminator, it does not matter what the remaining loop body looks like. If there is an insertion into a tensor that is defined outside of the loop, then One-Shot Bufferize will allocate a thread-local buffer copy.

> Is there any way other than SubsetInsertionOpInterface to check whether a value is used as the "destination" of an op?

I think there is a special interface for ops that can appear in the in_parallel terminator region. You could query that interface. Hopefully the destination can be queried from it. If not, you can add an interface method for that. For the moment you could also just hard-code the implementation to parallel_insert_slice because that's the only terminator that we support at the moment anyway.
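A minimal sketch of the hard-coded check suggested above, written against a toy model instead of the real MLIR API so that it compiles on its own: `OpKind`, `ModelOp`, `ModelForall`, `isDestinationInTerminator`, and `canFoldIterArg` are hypothetical names. Only the ops of the `scf.forall.in_parallel` terminator are inspected; the rest of the loop body is ignored, mirroring the argument that the remaining loop body does not matter.

```cpp
#include <cassert>
#include <vector>

// Hypothetical stand-ins for MLIR ops (NOT the MLIR C++ API). Values are
// identified by integer ids; block arguments use non-negative ids.
enum class OpKind { ParallelInsertSlice, Other };

struct ModelOp {
  OpKind kind;
  int dest;                  // destination value id, or -1 if none
  std::vector<int> operands; // ids of the values the op reads
};

struct ModelForall {
  std::vector<ModelOp> body;       // ops in the loop body (never inspected)
  std::vector<ModelOp> terminator; // ops in scf.forall.in_parallel
};

// True if some parallel_insert_slice in the terminator uses `bbArg` as its
// destination. Only parallel_insert_slice is recognized, as suggested above.
bool isDestinationInTerminator(const ModelForall &loop, int bbArg) {
  assert(bbArg >= 0 && "block arguments use non-negative ids");
  for (const ModelOp &op : loop.terminator)
    if (op.kind == OpKind::ParallelInsertSlice && op.dest == bbArg)
      return true;
  return false;
}

// Folding condition from the doc comment: the tied result is unused, or the
// iter arg is never written to by the terminator.
bool canFoldIterArg(const ModelForall &loop, int bbArg, bool resultHasUses) {
  return !resultHasUses || !isDestinationInTerminator(loop, bbArg);
}
```

For example, a body op that writes into iter arg 0 does not prevent folding it, while a `parallel_insert_slice` into iter arg 1 in the terminator keeps that arg alive as long as its result is used. In MLIR itself, the equivalent check would presumably `dyn_cast` each op of the terminator to `tensor::ParallelInsertSliceOp` and compare its destination with the block argument.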

> Why do we not require subset insertion ops to implement the SubsetInsertionOpInterface ?

One reason is that we could not handle unregistered ops correctly. Maybe there's a way to support that safely… it’s kind of like the MemoryEffectsOpInterface: if an op does not implement that interface, it does not mean that there is no side effect; it just means that we don’t know the side effects.
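The same reasoning can be expressed as a conservative query. This is again a toy model with hypothetical names (`InsertionInfo`, `ModelUser`, `mayInsertInto`, `noUserMayInsert`), not MLIR API. An op about which nothing is known, such as an unregistered op, is treated as possibly inserting into the iter arg, in the same way that an op without `MemoryEffectsOpInterface` is treated as having unknown effects rather than none. By contrast, a check of the form `!isa<SubsetInsertionOpInterface>(user)` treats every op that lacks the interface as harmless.

```cpp
#include <cassert>
#include <vector>

// Hypothetical model: what is known about how an op affects its operands.
enum class InsertionInfo {
  KnownInsertion,   // implements an insertion interface; dest is known
  KnownNoInsertion, // registered op known not to insert into any operand
  Unknown           // e.g. an unregistered op: effects are not known
};

struct ModelUser {
  InsertionInfo info;
  int dest; // destination value id, meaningful only for KnownInsertion
};

// Conservative query: may `user` write into the value `bbArg`?
// Missing information is treated as "may write".
bool mayInsertInto(const ModelUser &user, int bbArg) {
  switch (user.info) {
  case InsertionInfo::KnownInsertion:
    return user.dest == bbArg;
  case InsertionInfo::KnownNoInsertion:
    return false;
  case InsertionInfo::Unknown:
    return true;
  }
  return true; // unreachable; keeps -Wreturn-type quiet
}

// An iter arg is a folding candidate only if no user may insert into it.
bool noUserMayInsert(const std::vector<ModelUser> &users, int bbArg) {
  assert(bbArg >= 0 && "block arguments use non-negative ids");
  for (const ModelUser &user : users)
    if (mayInsertInto(user, bbArg))
      return false;
  return true;
}
```

The trade-off is that an unknown op blocks the fold even when it is actually harmless; the pattern misses an optimization but never produces wrong IR.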

Assuming we only support parallel_insert_slice, does that handle all cases that you were thinking of? (I think it will work in the example that you posted.)
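Once the check is limited to `parallel_insert_slice` destinations, the remaining work of such a pattern is bookkeeping. The following toy model is not MLIR API and not the PR's code; `SharedOut`, `Replacement`, `FoldPlan`, and `planIterArgFolding` are hypothetical names. For the two cases in the doc comment, it computes which `shared_outs` survive, what each old result is replaced with, and which folded iter args still have terminator stores that must be dropped.

```cpp
#include <cassert>
#include <cstddef>
#include <string>
#include <vector>

// Hypothetical model of one shared_outs entry of an scf.forall.
struct SharedOut {
  std::string init;   // init operand, e.g. "%a"
  bool resultHasUses; // does the tied loop result have uses?
  bool isDestination; // is the iter arg a destination in the terminator?
};

// Replacement for one result of the old loop.
struct Replacement {
  bool fromNewLoop;     // true: a result of the new loop
  std::size_t newIndex; // index into the new loop's results (if fromNewLoop)
  std::string init;     // the init operand (used if !fromNewLoop)
};

struct FoldPlan {
  std::vector<std::string> newInits;      // shared_outs of the new loop
  std::vector<Replacement> replacements;  // one per old result
  std::vector<std::size_t> storesToErase; // folded iter args that are still
                                          // written by the terminator
};

FoldPlan planIterArgFolding(const std::vector<SharedOut> &outs) {
  FoldPlan plan;
  for (std::size_t i = 0; i < outs.size(); ++i) {
    const SharedOut &out = outs[i];
    bool foldable = !out.resultHasUses || !out.isDestination;
    if (!foldable) {
      // Kept: the old result maps to the next result of the new loop.
      plan.replacements.push_back({true, plan.newInits.size(), out.init});
      plan.newInits.push_back(out.init);
      continue;
    }
    // Folded: body uses of the iter arg (and any uses of the result) are
    // replaced by the init operand.
    plan.replacements.push_back({false, 0, out.init});
    // First case (unused result): its terminator stores lose their
    // destination in the new loop, so they must be erased.
    if (out.isDestination)
      plan.storesToErase.push_back(i);
  }
  assert(plan.replacements.size() == outs.size());
  return plan;
}
```

The second case corresponds to `return %a, %res` in the doc comment's output: result 0 is replaced by its init operand `%a`, and result 1 by the only result of the new loop.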



https://github.com/llvm/llvm-project/pull/90189

