[Mlir-commits] [mlir] [mlir][Linalg]: Add rewrite pattern to fuse fill with reduce operations (PR #125401)
Kunwar Grover
llvmlistbot at llvm.org
Mon Feb 3 01:45:20 PST 2025
Groverkss wrote:
> Hey @MaheshRavishankar, we are using the tile-and-fuse transform in addition to the regular linalg fusion. When lowered to loops, I would like to achieve the following pseudo-code:
>
> ```
> for (int i = 0; i < N; ++i) {
>   int sum = 0;
>   for (int j = 0; j < M; ++j) {
>     int x = load[i][j];
>     sum += x;
>   }
> }
> ```
>
> I can lower the linalg.generic of fill and reduce into loops and do it then, but at that point the flow is much more complicated for identifying the fill + reduce pattern. This pattern is useful for us even though it produces loops that are not perfectly nested. Moreover, our HW knows how to handle such calculations even though the loops are not perfectly nested.
>
> I do have some patterns to optimize the non-perfectly-nested loops after lowering, but they are more HW-specific.
>
> I think that some other people might use this transformation pattern as well.
I don't think doing this at the linalg.generic level is correct, because you cannot assume a single iteration order. IIUC, what you are trying to do here is:
```mlir
%empty = linalg.fill
linalg.generic ins(...) outs(%empty) { iterator_types = [parallel, reduction] }
```
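For intuition, assuming an N x M input reduced along its second dimension (the names `out`, `in`, `N`, and `M` are placeholders, and the element type is picked arbitrarily), these two ops compute roughly:
```c
for (int i = 0; i < N; ++i)
  out[i] = 0;                    // linalg.fill
for (int i = 0; i < N; ++i)      // parallel dimension
  for (int j = 0; j < M; ++j)    // reduction dimension
    out[i] += in[i][j];          // linalg.generic body
```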
1. TileAndFuse along the reduction dimension (the fill will not fuse, because it doesn't have the reduction dimension):
```mlir
%empty = linalg.fill
scf.for %j = 0 to ... iter_args(%arg0 = %empty) {
  %out = linalg.generic ins(...) outs(%arg0)
  scf.yield %out
}
```
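In loop form, assuming a reduction tile size of 1 for readability (same placeholder names as above), this step corresponds roughly to:
```c
for (int i = 0; i < N; ++i)      // linalg.fill, still outside the reduction loop
  out[i] = 0;
for (int j = 0; j < M; ++j)      // scf.for %j over the tiled reduction dimension
  for (int i = 0; i < N; ++i)    // parallel dimension left inside the generic
    out[i] += in[i][j];
```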
2. TileAndFuse along the parallel dimension (the fill will fuse, because it has the parallel dimension):
```mlir
scf.for %i = 0 to ... {
  %empty = linalg.fill
  scf.for %j = 0 to ... iter_args(%arg0 = %empty) {
    %out = linalg.generic ins(...) outs(%arg0)
    scf.yield %out
  }
}
```
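Again assuming tile size 1 on both dimensions, this is already the computation from the pseudo-code above:
```c
for (int i = 0; i < N; ++i) {    // scf.for %i over the tiled parallel dimension
  out[i] = 0;                    // linalg.fill, now fused under the parallel loop
  for (int j = 0; j < M; ++j)    // scf.for %j over the tiled reduction dimension
    out[i] += in[i][j];
}
```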
3. This is the form you were looking for. Now, if you still want a perfectly nested loop form, you can write a loop-sinking pass. You can sink the computation of any iter_args init (at least for value-based semantics, i.e. tensors in this case) behind an if condition:
```mlir
%em = tensor.empty()
scf.for %i = 0 to ... {
  scf.for %j = 0 to ... iter_args(%arg0 = %em) {
    %filled = scf.if (%j == 0) {
      %fill = linalg.fill outs(%arg0)
      scf.yield %fill
    } else {
      scf.yield %arg0
    }
    %out = linalg.generic ins(...) outs(%filled)
    scf.yield %out
  }
}
```
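In loop form (same placeholder names as above), the sunk fill ends up as a guard inside the perfectly nested loops:
```c
for (int i = 0; i < N; ++i)
  for (int j = 0; j < M; ++j) {
    if (j == 0)                  // the sunk fill runs once per row
      out[i] = 0;
    out[i] += in[i][j];
  }
```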
https://github.com/llvm/llvm-project/pull/125401