[Mlir-commits] [mlir] [mlir][PartialReductionTilingInterface] Generalize implementation of `tileUsingSCF` for `ReductionTilingStrategy::PartialOuterReduction`. (PR #143467)

Pietro Ghiglio llvmlistbot at llvm.org
Tue Sep 23 02:53:45 PDT 2025


PietroGhg wrote:

> > Note that the op doesn't implement destination-passing style, so I couldn't figure out how to implement tiling for this op using the `FullReduction` tiling strategy, and I implemented it using the `partialReductionTilingInterface`,
> 
> The `FullReduction` tiling should work on an operation that is not destination-passing style either... If the transformation is failing for that, that's just a wrong check. We should be able to fix that if you provide a repro.

There weren't any failing checks; I simply couldn't find a way to tile the op along the reduction dimension. The `getTiledImplementation` function doesn't expose any kind of "accumulator" value, and I don't see how to implement tiling for a reduction op without one. A destination-passing-style op doesn't need that because (at least in the cases I looked into, `linalg.reduce` and `linalg.matmul`) the "destination" operand also acts as the accumulator while tiling.
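A rough sketch of that point in plain Python (not MLIR; `reduce_into` and `tiled_reduce_columns` are hypothetical names used only for illustration): when the reduction takes a destination operand, the tiling loop can simply thread that operand through as the running accumulator.

```python
def reduce_into(tile, acc):
    """Destination-style row sum: adds each row of `tile` into the matching entry of `acc`."""
    return [a + sum(row) for row, a in zip(tile, acc)]


def tiled_reduce_columns(data, tile_size):
    """Tile the reduction (column) dimension; `acc` plays the role of the `outs` operand."""
    acc = [0.0] * len(data)  # zero-filled destination
    for j in range(0, len(data[0]), tile_size):
        chunk = [row[j:j + tile_size] for row in data]
        acc = reduce_into(chunk, acc)  # the destination doubles as the accumulator
    return acc
```

An op that only returns a value has no such operand to thread through the loop, which is the difficulty described above.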
For `linalg.reduce`, the `FullReduction` tiling strategy, starting from
```
func.func @reduce_columns(%input: tensor<128x256xf32>) -> tensor<128xf32> {
  %c0 = arith.constant 0.0 : f32
  %init = tensor.empty() : tensor<128xf32>
  %zero_init = linalg.fill ins(%c0 : f32) outs(%init : tensor<128xf32>) -> tensor<128xf32>
  
  %result = linalg.reduce ins(%input : tensor<128x256xf32>) 
                         outs(%zero_init : tensor<128xf32>) 
                         dimensions = [1] 
    (%in: f32, %acc: f32) {
      %sum = arith.addf %in, %acc : f32
      linalg.yield %sum : f32
    }
  
  return %result : tensor<128xf32>
}

module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
    %red = transform.structured.match ops{["linalg.reduce"]} in %arg1
      : (!transform.any_op) -> !transform.any_op
    %a, %loops:2 = transform.structured.tile_using_for %red tile_sizes [64, 64]
      : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
    transform.yield
  }
}
```
produces
```
  func.func @reduce_columns(%arg0: tensor<128x256xf32>) -> tensor<128xf32> {
    %c64 = arith.constant 64 : index
    %c256 = arith.constant 256 : index
    %c128 = arith.constant 128 : index
    %c0 = arith.constant 0 : index
    %cst = arith.constant 0.000000e+00 : f32
    %0 = tensor.empty() : tensor<128xf32>
    %1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<128xf32>) -> tensor<128xf32>
    %2 = scf.for %arg1 = %c0 to %c128 step %c64 iter_args(%arg2 = %1) -> (tensor<128xf32>) {
      %3 = scf.for %arg3 = %c0 to %c256 step %c64 iter_args(%arg4 = %arg2) -> (tensor<128xf32>) {
        %extracted_slice = tensor.extract_slice %arg0[%arg1, %arg3] [64, 64] [1, 1] : tensor<128x256xf32> to tensor<64x64xf32>
        %extracted_slice_0 = tensor.extract_slice %arg4[%arg1] [64] [1] : tensor<128xf32> to tensor<64xf32>
        %reduced = linalg.reduce ins(%extracted_slice : tensor<64x64xf32>) outs(%extracted_slice_0 : tensor<64xf32>) dimensions = [1] 
          (%in: f32, %init: f32) {
            %4 = arith.addf %in, %init : f32
            linalg.yield %4 : f32
          }
        %inserted_slice = tensor.insert_slice %reduced into %arg4[%arg1] [64] [1] : tensor<64xf32> into tensor<128xf32>
        scf.yield %inserted_slice : tensor<128xf32>
      }
      scf.yield %3 : tensor<128xf32>
    }
    return %2 : tensor<128xf32>
  }
```
whereas with my previous implementation of the `PartialReductionOpInterface` for the `myred` op, we started from

```
func.func @accumsum_3d(%arg0: tensor<30x256x1024xf32>) -> tensor<30x1024xf32> {
  %0 = "myred"(%arg0) : (tensor<30x256x1024xf32>) -> tensor<30x1024xf32>
  return %0 : tensor<30x1024xf32>
}
```
and got to
```
  func.func @accumsum_3d(%[[ARG0:.*]]: tensor<30x256x1024xf32>) -> tensor<30x1024xf32> {
    %[[C512:.*]] = arith.constant 512 : index
    %[[C64:.*]] = arith.constant 64 : index
    %[[C1:.*]] = arith.constant 1 : index
    %[[C1024:.*]] = arith.constant 1024 : index
    %[[C256:.*]] = arith.constant 256 : index
    %[[C30:.*]] = arith.constant 30 : index
    %[[C0:.*]] = arith.constant 0 : index
    %[[CST:.*]] = arith.constant 0.000000e+00 : f32
    %[[n0:.*]] = tensor.empty() : tensor<30x1024xf32>
    %[[n1:.*]] = linalg.fill ins(%[[CST]] : f32) outs(%[[n0]] : tensor<30x1024xf32>) -> tensor<30x1024xf32>
    %[[n2:.*]] = scf.for %[[ARG1:.*]] = %[[C0]] to %[[C30]] step %[[C1]] iter_args(%[[ARG2:.*]] = %[[n1]]) -> (tensor<30x1024xf32>) {
      %[[n3:.*]] = scf.for %[[ARG3:.*]] = %[[C0]] to %[[C256]] step %[[C64]] iter_args(%[[ARG4:.*]] = %[[ARG2]]) -> (tensor<30x1024xf32>) {
        %[[n4:.*]] = scf.for %[[ARG5:.*]] = %[[C0]] to %[[C1024]] step %[[C512]] iter_args(%[[ARG6:.*]] = %[[ARG4]]) -> (tensor<30x1024xf32>) {
          %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[ARG6]][%[[ARG1]], %[[ARG5]]] [1, 512] [1, 1] : tensor<30x1024xf32> to tensor<512xf32>
          %[[EXTRACTED_SLICE_0:.*]] = tensor.extract_slice %[[ARG0]][%[[ARG1]], %[[ARG3]], %[[ARG5]]] [1, 64, 512] [1, 1, 1] : tensor<30x256x1024xf32> to tensor<64x512xf32>
          %[[n5:.*]] = "myred"(%[[EXTRACTED_SLICE_0]]) <{tiled}> : (tensor<64x512xf32>) -> tensor<512xf32>
          %[[n6:.*]] = arith.addf %[[EXTRACTED_SLICE]], %[[n5]] : tensor<512xf32>
          %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[n6]] into %[[ARG6]][%[[ARG1]], %[[ARG5]]] [1, 512] [1, 1] : tensor<512xf32> into tensor<30x1024xf32>
          scf.yield %[[INSERTED_SLICE]] : tensor<30x1024xf32>
        }
        scf.yield %[[n4]] : tensor<30x1024xf32>
      }
      scf.yield %[[n3]] : tensor<30x1024xf32>
    }
    return %[[n2]] : tensor<30x1024xf32>
  }
```
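The shape of that loop nest can be sketched in plain Python (an illustration only, not MLIR; `myred` here is a hypothetical stand-in that sums a 2-D tile over its first axis, and the tile sizes are parameters rather than the 1/64/512 used above). The op itself takes no accumulator, so each tile's partial result is merged into the output slice by a separate add, mirroring the `arith.addf` after the tiled `"myred"`.

```python
def myred(tile):
    """Stand-in for the accumulator-free op: sums a 2-D tile over its first axis."""
    return [sum(col) for col in zip(*tile)]


def tiled_myred(data, red_tile, par_tile):
    """`data` is indexed [batch][reduction][parallel]; returns [batch][parallel]."""
    batch, red, par = len(data), len(data[0]), len(data[0][0])
    out = [[0.0] * par for _ in range(batch)]  # zero-filled result, like linalg.fill
    for b in range(batch):
        for r in range(0, red, red_tile):
            for p in range(0, par, par_tile):
                tile = [row[p:p + par_tile] for row in data[b][r:r + red_tile]]
                partial = myred(tile)  # reduce one tile, no accumulator involved
                # Explicit merge step: add the partial result into the output slice.
                for k, value in enumerate(partial):
                    out[b][p + k] += value
    return out
```

With exact arithmetic the tiled result matches applying `myred` to each whole batch entry, which is the property the explicit merge step is there to preserve.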
Sorry for the lengthy reply. Maybe we are going a bit off topic, but I couldn't see how to achieve the above for my op using the `FullReduction` tiling strategy.


https://github.com/llvm/llvm-project/pull/143467

