[Mlir-commits] [mlir] [mlir][linalg] Produce canonical linalg.generic for im2col (PR #134675)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Apr 7 09:17:32 PDT 2025


fabrizio-indirli wrote:

Before this patch, the following input IR:
```
func.func @conv_2d_nhwc_fhwc(%arg0: tensor<1x16x16x4xf32>, %arg1: tensor<16x3x3x4xf32>, %arg2: tensor<1x14x14x16xf32>) -> tensor<1x14x14x16xf32> {
    %0 = linalg.conv_2d_nhwc_fhwc
      {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64> }
       ins(%arg0, %arg1: tensor<1x16x16x4xf32>, tensor<16x3x3x4xf32>)
      outs(%arg2: tensor<1x14x14x16xf32>) -> tensor<1x14x14x16xf32>
    return %0 : tensor<1x14x14x16xf32>
}
```
would be converted to:
```
#map = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
#map1 = affine_map<()[s0] -> (s0 floordiv 14)>
#map2 = affine_map<()[s0] -> (s0 mod 14)>
#map3 = affine_map<()[s0] -> (s0 floordiv 12)>
#map4 = affine_map<()[s0] -> (s0 mod 12)>
#map5 = affine_map<()[s0] -> ((s0 mod 12) floordiv 4)>
#map6 = affine_map<()[s0] -> (s0 mod 4)>
#map7 = affine_map<()[s0, s1] -> (s0 floordiv 14 + s1 floordiv 12)>
#map8 = affine_map<()[s0, s1] -> (s0 mod 14 + (s1 mod 12) floordiv 4)>
#map9 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
#map10 = affine_map<(d0, d1, d2, d3) -> (d2, d3)>
#map11 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
func.func @conv_2d_nhwc_fhwc(%arg0: tensor<1x16x16x4xf32>, %arg1: tensor<16x3x3x4xf32>, %arg2: tensor<1x14x14x16xf32>) -> tensor<1x14x14x16xf32> {
    %collapsed = tensor.collapse_shape %arg1 [[0], [1, 2, 3]] : tensor<16x3x3x4xf32> into tensor<16x36xf32>
    %collapsed_0 = tensor.collapse_shape %arg2 [[0], [1, 2], [3]] : tensor<1x14x14x16xf32> into tensor<1x196x16xf32>
    %0 = tensor.empty() : tensor<1x196x36xf32>
    %1 = linalg.generic {indexing_maps = [#map], iterator_types = ["parallel", "parallel", "parallel"]} outs(%0 : tensor<1x196x36xf32>) {
    ^bb0(%out: f32):
      %3 = linalg.index 0 : index
      %4 = linalg.index 1 : index
      %5 = linalg.index 2 : index
      %c14 = arith.constant 14 : index
      %c14_1 = arith.constant 14 : index
      %c14_2 = arith.constant 14 : index
      %6 = affine.apply #map1()[%4]
      %7 = affine.apply #map2()[%4]
      %c3 = arith.constant 3 : index
      %c3_3 = arith.constant 3 : index
      %c4 = arith.constant 4 : index
      %c4_4 = arith.constant 4 : index
      %c12 = arith.constant 12 : index
      %8 = affine.apply #map3()[%5]
      %9 = affine.apply #map4()[%5]
      %10 = affine.apply #map5()[%5]
      %11 = affine.apply #map6()[%5]
      %12 = affine.apply #map7()[%4, %5]
      %13 = affine.apply #map8()[%4, %5]
      %extracted = tensor.extract %arg0[%3, %12, %13, %11] : tensor<1x16x16x4xf32>
      linalg.yield %extracted : f32
    } -> tensor<1x196x36xf32>
    %2 = linalg.generic {indexing_maps = [#map9, #map10, #map11], iterator_types = ["parallel", "parallel", "parallel", "reduction"]} ins(%1, %collapsed : tensor<1x196x36xf32>, tensor<16x36xf32>) outs(%collapsed_0 : tensor<1x196x16xf32>) {
    ^bb0(%in: f32, %in_1: f32, %out: f32):
      %3 = arith.mulf %in, %in_1 : f32
      %4 = arith.addf %3, %out : f32
      linalg.yield %4 : f32
    } -> tensor<1x196x16xf32>
    %expanded = tensor.expand_shape %2 [[0], [1, 2], [3]] output_shape [1, 14, 14, 16] : tensor<1x196x16xf32> into tensor<1x14x14x16xf32>
    return %expanded : tensor<1x14x14x16xf32>
}
```
while with this patch it is transformed to:
```
#map = affine_map<(d0, d1, d2) -> (d0, d1 floordiv 14 + d2 floordiv 12, d1 mod 14 + (d2 mod 12) floordiv 4, d2 mod 4)>
#map1 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
#map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
#map3 = affine_map<(d0, d1, d2, d3) -> (d2, d3)>
#map4 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
func.func @conv_2d_nhwc_fhwc(%arg0: tensor<1x16x16x4xf32>, %arg1: tensor<16x3x3x4xf32>, %arg2: tensor<1x14x14x16xf32>) -> tensor<1x14x14x16xf32> {
    %collapsed = tensor.collapse_shape %arg1 [[0], [1, 2, 3]] : tensor<16x3x3x4xf32> into tensor<16x36xf32>
    %collapsed_0 = tensor.collapse_shape %arg2 [[0], [1, 2], [3]] : tensor<1x14x14x16xf32> into tensor<1x196x16xf32>
    %0 = tensor.empty() : tensor<1x196x36xf32>
    %1 = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%arg0 : tensor<1x16x16x4xf32>) outs(%0 : tensor<1x196x36xf32>) {
    ^bb0(%in: f32, %out: f32):
      linalg.yield %in : f32
    } -> tensor<1x196x36xf32>
    %2 = linalg.generic {indexing_maps = [#map2, #map3, #map4], iterator_types = ["parallel", "parallel", "parallel", "reduction"]} ins(%1, %collapsed : tensor<1x196x36xf32>, tensor<16x36xf32>) outs(%collapsed_0 : tensor<1x196x16xf32>) {
    ^bb0(%in: f32, %in_1: f32, %out: f32):
      %3 = arith.mulf %in, %in_1 : f32
      %4 = arith.addf %3, %out : f32
      linalg.yield %4 : f32
    } -> tensor<1x196x16xf32>
    %expanded = tensor.expand_shape %2 [[0], [1, 2], [3]] output_shape [1, 14, 14, 16] : tensor<1x196x16xf32> into tensor<1x14x14x16xf32>
    return %expanded : tensor<1x14x14x16xf32>
}
```
Thus, the input tensor `%arg0` is now correctly listed in the linalg.generic's `ins()`, and the input access offsets are computed through ordinary indexing maps. This simplifies the code quite a bit and produces more canonical linalg.generic ops.

However, I see that the current approach is the result of rewriting previous code that already produced canonical linalg.generic ops. I also found a couple of old discussions on the current code ([1](https://reviews.llvm.org/D144108) and [2](https://reviews.llvm.org/D144678)) where it was mentioned that
> I know this pattern exists in IREE, but it is a bit of a hack. The representation of the im2col as a linalg.generic doesnt work as well. In reality it is similar to a gather. If this is upstreamed, it might be worth doing this right and not use an unnecessarily higher-dimensional operation for representing the im2col.

Thus, I was wondering whether the current code produces non-canonical linalg on purpose, for a specific reason. Since I couldn't be sure, as a first step I modified only one of the four rewrites in the transform. If anyone can confirm that this approach is correct, I'll be happy to update the other ones as well.
@qcolombet  @nicolasvasilache  @MaheshRavishankar  @ThomasRaoux 

https://github.com/llvm/llvm-project/pull/134675


More information about the Mlir-commits mailing list