[Mlir-commits] [mlir] [mlir][Linalg]: Add rewrite pattern to fuse fill with reduce operations (PR #125401)

Kunwar Grover llvmlistbot at llvm.org
Mon Feb 3 01:45:20 PST 2025


Groverkss wrote:

> Hey @MaheshRavishankar, we are using the tile-and-fuse transform in addition to the regular linalg fusion. When lowering to loops, I would like to achieve the following pseudo-code:
> 
> ```
> for (int i = 0; i < N; ++i) {
>     int sum = 0;
>     for (int j = 0; j < M; ++j) {
>         int x = load[i][j];
>         sum += x;
>     }
> }
> ```
> 
> I can lower the linalg.generic of fill and reduce into loops and do it there, but at that point the flow is much more complicated for identifying the fill + reduce pattern. This pattern is useful for us even though it produces non-perfectly-nested loops. Moreover, our HW knows how to handle such calculations even though the loops are not perfectly nested.
> 
> I do have some patterns to optimize the non-perfectly-nested loops after the lowering, but they are more HW-specific.
> 
> I think that some other people might use this transformation pattern as well.


I don't think doing this at the linalg.generic level is correct, because you cannot assume a single iteration order there. IIUC, what you are trying to achieve can be done as follows, starting from:

```mlir
%empty = linalg.fill
linalg.generic ins(...) outs(%empty) { iterator_types = [parallel, reduction] }
```
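For concreteness, a fill + reduction pair like the one sketched above could look roughly like this (shapes, names such as `%input`/`%n`, and the element type are illustrative, not taken from the PR):

```mlir
%cst = arith.constant 0.0 : f32
%init = tensor.empty(%n) : tensor<?xf32>
// Initialize the accumulator, then reduce along the second (reduction) dimension.
%empty = linalg.fill ins(%cst : f32) outs(%init : tensor<?xf32>) -> tensor<?xf32>
%result = linalg.generic {
    indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
                     affine_map<(d0, d1) -> (d0)>],
    iterator_types = ["parallel", "reduction"]}
    ins(%input : tensor<?x?xf32>) outs(%empty : tensor<?xf32>) {
  ^bb0(%in: f32, %acc: f32):
    %sum = arith.addf %in, %acc : f32
    linalg.yield %sum : f32
  } -> tensor<?xf32>
```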

1. TileAndFuse along the reduction dimension (the fill will not fuse, because it doesn't have the reduction dimension):

```mlir
%empty = linalg.fill
scf.for %j = 0 to ... iter_args(%arg0 = %empty) {
  %out = linalg.generic ins(...) outs(%arg0)
  yield %out
}
```

2. TileAndFuse along the parallel dimension (the fill will fuse, because it has the parallel dimension):

```mlir
scf.for %i = 0 to ... {
   %empty = linalg.fill
   scf.for %j = 0 to ... iter_args(%arg0 = %empty) {
    %out = linalg.generic ins(...) outs(%arg0)
    yield %out
  }
}
```

3. This is the form you were looking for. Now, if you still want a perfectly nested loop form, you can write a loop-sinking pass. You can sink any iter_args (with value-based semantics at least, i.e. tensors in this case) under an if condition:

```mlir
%em = tensor.empty()
scf.for %i = 0 to ... {
   scf.for %j = 0 to ... iter_args(%arg0 = %em) {
    // %arg0 is visible inside the scf.if, so no extra block arguments are needed.
    %is_first = arith.cmpi eq, %j, %c0 : index
    %filled = scf.if %is_first -> (tensor<...>) {
      %fill = linalg.fill outs(%arg0)
      scf.yield %fill
    } else {
      scf.yield %arg0
    }
    %out = linalg.generic ins(...) outs(%filled)
    scf.yield %out
  }
}
```
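After bufferization and lowering to loops, this corresponds roughly to the perfectly nested form from your pseudo-code, with the initialization sunk under the if. This is an illustrative sketch assuming a memref-based lowering; `%in`, `%acc`, `%zero`, and the bounds are hypothetical names:

```mlir
scf.for %i = %c0 to %N step %c1 {
  scf.for %j = %c0 to %M step %c1 {
    // Sunk fill: initialize the accumulator only on the first reduction iteration.
    %is_first = arith.cmpi eq, %j, %c0 : index
    scf.if %is_first {
      memref.store %zero, %acc[%i] : memref<?xf32>
    }
    %x = memref.load %in[%i, %j] : memref<?x?xf32>
    %a = memref.load %acc[%i] : memref<?xf32>
    %sum = arith.addf %x, %a : f32
    memref.store %sum, %acc[%i] : memref<?xf32>
  }
}
```

The scf.if only guards the initialization, so the accumulation itself stays perfectly nested.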

https://github.com/llvm/llvm-project/pull/125401

