[Mlir-commits] [mlir] [mlir][Linalg]: Add rewrite pattern to fuse fill with reduce operations (PR #125401)
Aviad Cohen
llvmlistbot at llvm.org
Tue Feb 4 06:47:56 PST 2025
AviadCo wrote:
@MaheshRavishankar @Groverkss
I really appreciate your deep review and suggestions. I agree that this transform as it is should not be merged.
I think that linalg could benefit from an optional "order" attribute (in our case, for example, we are not allowed to change the order of float operations and must apply them serially).
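To illustrate why the order matters for floats, here is a minimal sketch in plain Python (not MLIR, and not part of the PR). The function names and input values are made up for the example; it only shows that reassociating a floating-point sum can change the result.

```python
# Illustrative only: floating-point addition is not associative, so a
# reduction that is reordered (e.g. by tiling or parallelizing) may not
# match the serial result bit-for-bit.

def serial_sum(xs):
    """Accumulate strictly left to right, as a serial reduction would."""
    acc = 0.0
    for x in xs:
        acc = acc + x
    return acc

def reassociated_sum(xs):
    """Hypothetical reordering of four elements: (x0 + x2) + (x1 + x3)."""
    return (xs[0] + xs[2]) + (xs[1] + xs[3])

values = [1e16, 1.0, -1e16, 1.0]

serial = serial_sum(values)           # the first 1.0 is absorbed by 1e16
reordered = reassociated_sum(values)  # the large terms cancel first

print(serial, reordered)
```

With these inputs the two sums differ, which is the kind of change a serial-order constraint is meant to rule out.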
> %cst = arith.constant 0.0 : f32
> %empty = tensor.empty(%N) : tensor<?xf32>
> %result = scf.for %iv0 = 0 to %N step 1 outs(%init = %empty) {
> %outs = tensor.extract_slice %init[%iv0][1][1]: tensor<?xf32> to tensor<1xf32>
> %fill = linalg.fill ins(%cst : f32) outs(%outs : tensor<1xf32>) -> tensor<1xf32>
> %slice = tensor.extract_slice %load[%iv0, 0][1, %M][1, 1]: tensor<?x?xf32> to tensor<1x?xf32>
> %reduction = scf.for %iv1 = 0 to %M step 1 outs(%init0 = %fill) {
> %slice0 = tensor.extract_slice %slice[0, %iv1][1, 1][1, 1] : tensor<1x?xf32> to tensor<1x1xf32>
> %generic = linalg.generic {
> iterator_types = ["parallel", "reduction"],
> indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>]}
> ins(%slice0 : tensor<1x1xf32>) outs(%init0 : tensor<1xf32>) {
> ^bb0(%b0 : f32, %b1 : f32):
> %0 = arith.addf %b0, %b1 : f32
> linalg.yield %0 : f32
> } -> tensor<1xf32>
> %inserted0 = tensor.insert_slice %generic into %init0[0][1][1] : tensor<1xf32> into tensor<1xf32>
> scf.yield %inserted0 : tensor<1xf32>
> }
> %inserted = tensor.insert_slice %reduction into %init[%iv0][1][1] : tensor<1xf32> into tensor<?xf32>
> scf.yield %inserted : tensor<?xf32>
> } -> tensor<?xf32>
Unfortunately, this flow leaves the final `linalg.generic` too trivial (it works on a single element), and our general flow depends on the `linalg.generic` being the actual heavy compute.
We do use the FuseAndTile pattern, and we do co-tile the `linalg.fill` and `linalg.reduce`. I will try to do the fusion later in the pipeline, where those operations have already been lowered to loops.
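As a rough picture of what fusing the fill into the reduction could look like once everything is lowered to loops, here is a sketch in plain Python. It is an assumption about the general shape of such a fusion, not the actual pass; the function names `fill_then_reduce` and `fused_fill_reduce` are hypothetical.

```python
# Illustrative only: a row-wise sum written as a separate "fill" loop
# followed by a "reduce" loop nest, and the same computation with the
# fill folded into the reduction loop nest.

def fill_then_reduce(matrix, init=0.0):
    """Unfused form: initialize the whole output, then accumulate."""
    out = [init for _ in matrix]           # separate fill loop
    for i, row in enumerate(matrix):       # parallel dimension
        for x in row:                      # reduction dimension
            out[i] = out[i] + x
    return out

def fused_fill_reduce(matrix, init=0.0):
    """Fused form: the fill becomes the accumulator's initial value."""
    out = []
    for row in matrix:                     # parallel dimension
        acc = init                         # fill happens here, per row
        for x in row:                      # reduction dimension, same serial order
            acc = acc + x
        out.append(acc)
    return out

data = [[1.0, 2.0, 3.0], [0.5, 0.25, 0.25]]
unfused = fill_then_reduce(data)
fused = fused_fill_reduce(data)
print(unfused, fused)
```

Both versions add the elements of each row in the same order, so the fused form keeps the serial float semantics while skipping the separate pass over the output.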
https://github.com/llvm/llvm-project/pull/125401