[Mlir-commits] [mlir] [mlir][scf] Extend consumer fuse to nested loop structure (PR #94190)
llvmlistbot at llvm.org
Thu Jun 6 19:18:40 PDT 2024
Yun-Fly wrote:
@MaheshRavishankar Thanks for your explanation! I understand both of your points now.
> Lets start with your example above. I think your input is
>
> ```
> %0 = linalg.fill .. outs(%empty)
> %1 = linalg.matmul ... outs(%0)
> %2 = linalg.add (..., %1)
> ```
I agree with you if the input starts this way. When tiling and fusion are applied together step by step, recursively calling `tileAndFuse` is another good option. However, what if `matmul` has already been fully tiled into a nested loop before fusion? For example, developers may have a hand-written, optimized `matmul` template or kernel specific to GPU or CPU that is decoupled from the fusion stage, and then want to fuse the subsequent `add` op. This is in fact the original motivation for this patch, as I described at the top of the PR description.
```
%0 = linalg.fill ... outs(%empty)
%1 = scf.forall ... shared_outs(%arg0 = %0) {
  %2 = tensor.extract_slice %arg0[...]
  %3 = scf.for ... iter_args(%arg1 = %2) {
    %4 = scf.for ... iter_args(%arg2 = %arg1) {
      %5 = tensor.extract_slice %arg2[...]
      %6 = linalg.matmul ... outs(%5)
      %7 = tensor.insert_slice %6 into %arg2[...]
      scf.yield %7
    }
    scf.yield %4
  }
  scf.forall.in_parallel {
    tensor.parallel_insert_slice %3 into %arg0[...]
  }
}
%8 = linalg.add (..., %1)
```
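With the current implementation, although it is possible to fuse `add` into the outermost `scf.forall` (`%1`), it seems hard to recursively fuse it into the next-level loop `scf.for` (`%3`), because that loop yields `%4` directly and there is no `tensor.insert_slice` at that level to serve as the candidate slice.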
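To make the difficulty concrete, here is a schematic sketch (same elided style as above, not actual output of the pass) of roughly what the IR might look like after fusing `add` into the outermost `scf.forall` with the existing consumer-fusion support. `%init` stands for the add's original init operand and, like `%init_slice` and `%add_tile`, is a hypothetical name. The tiled `add` now reads `%3`, the result of an `scf.for` whose body simply yields `%4`, so there is no slice op at that level to start the next fusion from.

```
%0 = linalg.fill ... outs(%empty)
%1:2 = scf.forall ... shared_outs(%arg0 = %0, %arg3 = %init) {
  %2 = tensor.extract_slice %arg0[...]
  %3 = scf.for ... iter_args(%arg1 = %2) {
    // ... unchanged inner loop nest, ending in `scf.yield %4` ...
  }
  %init_slice = tensor.extract_slice %arg3[...]
  // the tiled consumer reads the scf.for result %3 directly
  %add_tile = linalg.add ins(..., %3) outs(%init_slice)
  scf.forall.in_parallel {
    tensor.parallel_insert_slice %3 into %arg0[...]
    tensor.parallel_insert_slice %add_tile into %arg3[...]
  }
}
// uses of the original `add` result are replaced by %1#1
```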
> This is changing things much more than I would expect.
Compared with fusing the consumer from the outermost loop inward, step by step over multiple applications, the overall logic of this patch can be summarized in three steps for your review:
1. Enhance `getUntiledConsumer` to find the real consumer of the given candidate slice.
2. Reuse what #88712 already does to fuse the real consumer into the parent loop of the candidate slice.
3. Restore the outer loops in a way similar to the existing `addInitOperandsToLoopNest` method.
As you can see, only steps 1 and 3 are new; the other changes are mostly code refactoring for better reuse.
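For comparison, here is a schematic sketch of the result these three steps aim for when starting from the innermost candidate slice `%7` (again elided, with hypothetical value names; the exact IR the patch emits may differ). Presumably step 1 identifies `add` as the real consumer by following `%7` out through the loop yields (`%4`, `%3`) and the `parallel_insert_slice` to `%1`. Step 2 then fuses a tiled `add` next to `%7`, and step 3 threads the add's init through every enclosing loop as an extra `iter_args`/`shared_outs` operand and result.

```
%0 = linalg.fill ... outs(%empty)
%1:2 = scf.forall ... shared_outs(%arg0 = %0, %arg3 = %init) {
  %2 = tensor.extract_slice %arg0[...]
  %init_slice = tensor.extract_slice %arg3[...]
  %3:2 = scf.for ... iter_args(%arg1 = %2, %arg4 = %init_slice) {
    %4:2 = scf.for ... iter_args(%arg2 = %arg1, %arg5 = %arg4) {
      %5 = tensor.extract_slice %arg2[...]
      %6 = linalg.matmul ... outs(%5)
      %7 = tensor.insert_slice %6 into %arg2[...]
      // step 2: tiled consumer fused next to the candidate slice %7
      %out_slice = tensor.extract_slice %arg5[...]
      %add_tile = linalg.add ins(..., %6) outs(%out_slice)
      %add_ins = tensor.insert_slice %add_tile into %arg5[...]
      scf.yield %7, %add_ins
    }
    // step 3: the new result is yielded through each outer loop
    scf.yield %4#0, %4#1
  }
  scf.forall.in_parallel {
    tensor.parallel_insert_slice %3#0 into %arg0[...]
    tensor.parallel_insert_slice %3#1 into %arg3[...]
  }
}
// uses of the original `add` result are replaced by %1#1
```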
https://github.com/llvm/llvm-project/pull/94190