[Mlir-commits] [mlir] [MLIR][SCF] Add canonicalization pattern to fold away iter args of scf.forall (PR #90189)
Abhishek Varma
llvmlistbot at llvm.org
Sun Apr 28 22:43:51 PDT 2024
================
@@ -1509,6 +1510,203 @@ class ForallOpControlOperandsFolder : public OpRewritePattern<ForallOp> {
}
};
+/// The following canonicalization pattern folds the iter arguments of
+/// scf.forall op if :-
+/// 1. The corresponding result has zero uses.
+/// 2. The iter argument is NOT being modified within the loop body.
+///
+/// Example of first case :-
+/// INPUT:
+/// %res:3 = scf.forall ... shared_outs(%arg0 = %a, %arg1 = %b, %arg2 = %c)
+/// {
+/// ...
+/// <SOME USE OF %arg0>
+/// <SOME USE OF %arg1>
+/// <SOME USE OF %arg2>
+/// ...
+/// scf.forall.in_parallel {
+/// <STORE OP WITH DESTINATION %arg1>
+/// <STORE OP WITH DESTINATION %arg0>
+/// <STORE OP WITH DESTINATION %arg2>
+/// }
+/// }
+/// return %res#1
+///
+/// OUTPUT:
+///      %res = scf.forall ... shared_outs(%new_arg0 = %b)
+/// {
+/// ...
+/// <SOME USE OF %a>
+/// <SOME USE OF %new_arg0>
+/// <SOME USE OF %c>
+/// ...
+/// scf.forall.in_parallel {
+/// <STORE OP WITH DESTINATION %new_arg0>
+/// }
+/// }
+/// return %res
+///
+/// NOTE: 1. All uses of the folded shared_outs (iter argument) within the
+///          scf.forall are replaced by their corresponding operands.
+/// 2. The canonicalization assumes that there are no <STORE OP WITH
+/// DESTINATION *> ops within the body of the scf.forall except within
+/// scf.forall.in_parallel terminator.
+/// 3. The order of the <STORE OP WITH DESTINATION *> can be arbitrary
+/// within scf.forall.in_parallel - the code below takes care of this
+/// by traversing the uses of the corresponding iter arg.
+///
+/// Example of second case :-
+/// INPUT:
+/// %res:2 = scf.forall ... shared_outs(%arg0 = %a, %arg1 = %b)
+/// {
+/// ...
+/// <SOME USE OF %arg0>
+/// <SOME USE OF %arg1>
+/// ...
+/// scf.forall.in_parallel {
+/// <STORE OP WITH DESTINATION %arg1>
+/// }
+/// }
+/// return %res#0, %res#1
+///
+/// OUTPUT:
+/// %res = scf.forall ... shared_outs(%new_arg0 = %b)
+/// {
+/// ...
+/// <SOME USE OF %a>
+/// <SOME USE OF %new_arg0>
+/// ...
+/// scf.forall.in_parallel {
+/// <STORE OP WITH DESTINATION %new_arg0>
+/// }
+/// }
+/// return %a, %res
+struct ForallOpIterArgsFolder : public OpRewritePattern<ForallOp> {
+ using OpRewritePattern<ForallOp>::OpRewritePattern;
+
+  /// Utility function that checks if a candidate value satisfies any of the
+ /// conditions (see above doc comment) to make it viable for folding away.
+ static bool isCandidateValueToDelete(Value result, BlockArgument blockArg) {
+    if (result.use_empty())
+      return true;
+ Value::user_range users = blockArg.getUsers();
+ return llvm::all_of(users, [&](Operation *user) {
+ return !isa<SubsetInsertionOpInterface>(user);
----------------
Abhishek-Varma wrote:
Hi @matthias-springer . Thanks for the review and suggestion! I'm writing below a few corner cases which I feel might arise. Please let me know where I'm going wrong with my understanding and what the best way to deal with them is, and I shall do that. :)
> There may be ops that are inserting at a subset but do not implement the interface.
Oh okay. I wasn't aware of this.
Regarding `tensor.cast` - the check I've added would only return true if at least one `SubsetInsertionOpInterface` is found for the bbArg. Even if `tensor.cast`'s result is passed into a `SubsetInsertionOpInterface`, it won't have the same bbArg (the element type would be different), so it would be okay.
> But I think there's a simpler solution: if a shared_outs bbArg is not used as the "destination" of an op inside of the scf.forall.in_parallel terminator, it should be safe to use the init value inside the loop instead. Can you give that a try?
That is definitely correct and a simpler check, but I was trying to address the following case as well :-
```
scf.forall ... shared_outs(%arg0 = %a)
{
...
<SOME USE OF %arg0>
...
%x = tensor.insert_slice <some_val> into %arg0 (or some unregistered op that inserts a value into a subset)
...
<some use of %x>
...
scf.forall.in_parallel {
<STORE OP WITH DESTINATION %arg0 or some other bbArg> (currently `tensor.parallel_insert_slice` and `tensor.insert_slice` do that)
}
}
```
Therefore, if I'm only checking within `scf.forall.in_parallel`, it won't cater to the above case.
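For concreteness, a terminator-only check along the lines you suggested might look roughly like the sketch below (the helper name is mine, not the PR's code; it leans on `scf::ForallOp::getTerminator`, `scf::InParallelOp::getYieldingOps`, and `SubsetInsertionOpInterface::getDestinationOperand`, and as noted above it would not see subset insertions elsewhere in the loop body):

```cpp
// Hypothetical sketch: returns true if `bbArg` is used as the destination of
// a subset-insertion op inside the scf.forall.in_parallel terminator. If this
// returns false, replacing the bbArg with its init value should be safe under
// the assumption that all stores happen in the terminator.
static bool isUsedAsDestinationInTerminator(scf::ForallOp forallOp,
                                            BlockArgument bbArg) {
  scf::InParallelOp terminator = forallOp.getTerminator();
  for (Operation &yieldingOp : terminator.getYieldingOps()) {
    auto subsetOp = dyn_cast<SubsetInsertionOpInterface>(&yieldingOp);
    if (subsetOp && subsetOp.getDestinationOperand().get() == bbArg)
      return true;
  }
  return false;
}
```

But as the IR above shows, an insertion into %arg0 can also happen mid-body, which is exactly what this terminator-only walk would miss.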
So, the three things unknown to me at this point are :-
1. How to deal with unregistered ops which might appear anywhere in the loop body and insert a value into the bbArg?
2. Is there any other way to check if a value is used as a "destination" of an op besides `SubsetInsertionOpInterface` ?
3. Why do we not require subset insertion ops to implement the `SubsetInsertionOpInterface` ?
https://github.com/llvm/llvm-project/pull/90189