[Mlir-commits] [mlir] [mlir][linalg] Produce canonical linalg.generic for im2col (PR #134675)
llvmlistbot at llvm.org
Mon Apr 7 09:17:32 PDT 2025
fabrizio-indirli wrote:
Before this patch, the following input IR:
```mlir
func.func @conv_2d_nhwc_fhwc(%arg0: tensor<1x16x16x4xf32>, %arg1: tensor<16x3x3x4xf32>, %arg2: tensor<1x14x14x16xf32>) -> tensor<1x14x14x16xf32> {
%0 = linalg.conv_2d_nhwc_fhwc
{dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64> }
ins(%arg0, %arg1: tensor<1x16x16x4xf32>, tensor<16x3x3x4xf32>)
outs(%arg2: tensor<1x14x14x16xf32>) -> tensor<1x14x14x16xf32>
return %0 : tensor<1x14x14x16xf32>
}
```
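For reference, the shapes in the example line up as follows: with stride 1 and dilation 1, the output spatial size is `16 - 3 + 1 = 14` in each dimension. A minimal NumPy sketch of the equivalent NHWC/FHWC convolution (my own illustration of the op's semantics, not part of the patch):

```python
import numpy as np

# Naive NHWC x FHWC convolution (stride 1, dilation 1), mirroring
# linalg.conv_2d_nhwc_fhwc on the shapes from the example above.
def conv_2d_nhwc_fhwc(x, w):
    n, h, w_in, c = x.shape        # 1, 16, 16, 4
    f, kh, kw, _ = w.shape         # 16, 3, 3, 4
    oh, ow = h - kh + 1, w_in - kw + 1   # 14, 14
    out = np.zeros((n, oh, ow, f), dtype=x.dtype)
    for i in range(oh):
        for j in range(ow):
            # Each (kh, kw, c) input window is contracted against every filter.
            win = x[:, i:i + kh, j:j + kw, :]
            out[:, i, j, :] = np.tensordot(win, w, axes=([1, 2, 3], [1, 2, 3]))
    return out

x = np.random.rand(1, 16, 16, 4).astype(np.float32)
w = np.random.rand(16, 3, 3, 4).astype(np.float32)
print(conv_2d_nhwc_fhwc(x, w).shape)  # (1, 14, 14, 16)
```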
would be converted to:
```mlir
#map = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
#map1 = affine_map<()[s0] -> (s0 floordiv 14)>
#map2 = affine_map<()[s0] -> (s0 mod 14)>
#map3 = affine_map<()[s0] -> (s0 floordiv 12)>
#map4 = affine_map<()[s0] -> (s0 mod 12)>
#map5 = affine_map<()[s0] -> ((s0 mod 12) floordiv 4)>
#map6 = affine_map<()[s0] -> (s0 mod 4)>
#map7 = affine_map<()[s0, s1] -> (s0 floordiv 14 + s1 floordiv 12)>
#map8 = affine_map<()[s0, s1] -> (s0 mod 14 + (s1 mod 12) floordiv 4)>
#map9 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
#map10 = affine_map<(d0, d1, d2, d3) -> (d2, d3)>
#map11 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
func.func @conv_2d_nhwc_fhwc(%arg0: tensor<1x16x16x4xf32>, %arg1: tensor<16x3x3x4xf32>, %arg2: tensor<1x14x14x16xf32>) -> tensor<1x14x14x16xf32> {
%collapsed = tensor.collapse_shape %arg1 [[0], [1, 2, 3]] : tensor<16x3x3x4xf32> into tensor<16x36xf32>
%collapsed_0 = tensor.collapse_shape %arg2 [[0], [1, 2], [3]] : tensor<1x14x14x16xf32> into tensor<1x196x16xf32>
%0 = tensor.empty() : tensor<1x196x36xf32>
%1 = linalg.generic {indexing_maps = [#map], iterator_types = ["parallel", "parallel", "parallel"]} outs(%0 : tensor<1x196x36xf32>) {
^bb0(%out: f32):
%3 = linalg.index 0 : index
%4 = linalg.index 1 : index
%5 = linalg.index 2 : index
%c14 = arith.constant 14 : index
%c14_1 = arith.constant 14 : index
%c14_2 = arith.constant 14 : index
%6 = affine.apply #map1()[%4]
%7 = affine.apply #map2()[%4]
%c3 = arith.constant 3 : index
%c3_3 = arith.constant 3 : index
%c4 = arith.constant 4 : index
%c4_4 = arith.constant 4 : index
%c12 = arith.constant 12 : index
%8 = affine.apply #map3()[%5]
%9 = affine.apply #map4()[%5]
%10 = affine.apply #map5()[%5]
%11 = affine.apply #map6()[%5]
%12 = affine.apply #map7()[%4, %5]
%13 = affine.apply #map8()[%4, %5]
%extracted = tensor.extract %arg0[%3, %12, %13, %11] : tensor<1x16x16x4xf32>
linalg.yield %extracted : f32
} -> tensor<1x196x36xf32>
%2 = linalg.generic {indexing_maps = [#map9, #map10, #map11], iterator_types = ["parallel", "parallel", "parallel", "reduction"]} ins(%1, %collapsed : tensor<1x196x36xf32>, tensor<16x36xf32>) outs(%collapsed_0 : tensor<1x196x16xf32>) {
^bb0(%in: f32, %in_1: f32, %out: f32):
%3 = arith.mulf %in, %in_1 : f32
%4 = arith.addf %3, %out : f32
linalg.yield %4 : f32
} -> tensor<1x196x16xf32>
%expanded = tensor.expand_shape %2 [[0], [1, 2], [3]] output_shape [1, 14, 14, 16] : tensor<1x196x16xf32> into tensor<1x14x14x16xf32>
return %expanded : tensor<1x14x14x16xf32>
}
```
while with this patch it is transformed to:
```mlir
#map = affine_map<(d0, d1, d2) -> (d0, d1 floordiv 14 + d2 floordiv 12, d1 mod 14 + (d2 mod 12) floordiv 4, d2 mod 4)>
#map1 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
#map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
#map3 = affine_map<(d0, d1, d2, d3) -> (d2, d3)>
#map4 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
func.func @conv_2d_nhwc_fhwc(%arg0: tensor<1x16x16x4xf32>, %arg1: tensor<16x3x3x4xf32>, %arg2: tensor<1x14x14x16xf32>) -> tensor<1x14x14x16xf32> {
%collapsed = tensor.collapse_shape %arg1 [[0], [1, 2, 3]] : tensor<16x3x3x4xf32> into tensor<16x36xf32>
%collapsed_0 = tensor.collapse_shape %arg2 [[0], [1, 2], [3]] : tensor<1x14x14x16xf32> into tensor<1x196x16xf32>
%0 = tensor.empty() : tensor<1x196x36xf32>
%1 = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%arg0 : tensor<1x16x16x4xf32>) outs(%0 : tensor<1x196x36xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<1x196x36xf32>
%2 = linalg.generic {indexing_maps = [#map2, #map3, #map4], iterator_types = ["parallel", "parallel", "parallel", "reduction"]} ins(%1, %collapsed : tensor<1x196x36xf32>, tensor<16x36xf32>) outs(%collapsed_0 : tensor<1x196x16xf32>) {
^bb0(%in: f32, %in_1: f32, %out: f32):
%3 = arith.mulf %in, %in_1 : f32
%4 = arith.addf %3, %out : f32
linalg.yield %4 : f32
} -> tensor<1x196x16xf32>
%expanded = tensor.expand_shape %2 [[0], [1, 2], [3]] output_shape [1, 14, 14, 16] : tensor<1x196x16xf32> into tensor<1x14x14x16xf32>
return %expanded : tensor<1x14x14x16xf32>
}
```
Thus, the input tensor `%arg0` is now correctly reported in the linalg.generic's `ins()`, and the input access offsets are computed through a normal indexing map. This simplifies the generated IR quite a bit and produces more canonical linalg.generic ops.
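As a sanity check of my own (not part of the patch), the fused indexing map `#map` can be verified against the per-dimension decomposition the old code computed in the body: with `d1 = oh*14 + ow` flattening the 14x14 output window and `d2 = kh*12 + kw*4 + c` flattening the 3x3x4 filter window, the stride-1/dilation-1 im2col gather reads `%arg0[n, oh + kh, ow + kw, c]`:

```python
# Check that (d0, d1, d2) -> (d0, d1//14 + d2//12, d1%14 + (d2%12)//4, d2%4)
# reproduces the element the im2col gather needs, for every (d1, d2).
for d1 in range(14 * 14):          # flattened output spatial index
    for d2 in range(3 * 3 * 4):    # flattened filter window index
        oh, ow = d1 // 14, d1 % 14
        kh, kw, c = d2 // 12, (d2 % 12) // 4, d2 % 4
        fused = (d1 // 14 + d2 // 12, d1 % 14 + (d2 % 12) // 4, d2 % 4)
        assert fused == (oh + kh, ow + kw, c)
print("fused im2col map matches the per-dimension decomposition")
```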
However, I see that the current approach is the result of a rewrite of previous code that already produced a canonical linalg.generic. I also found a couple of old discussions on the current code ([1](https://reviews.llvm.org/D144108) and [2](https://reviews.llvm.org/D144678)) where it was mentioned that
> I know this pattern exists in IREE, but it is a bit of a hack. The representation of the im2col as a linalg.generic doesnt work as well. In reality it is similar to a gather. If this is upstreamed, it might be worth doing this right and not use an unnecessarily higher-dimensional operation for representing the im2col.
Thus, I was wondering whether the current code produces the non-canonical linalg.generic on purpose, for a specific reason. Since I couldn't be sure, as a first step I modified only one of the 4 rewrites in the transform. If anyone can confirm that this approach is correct, I'll be happy to update the other ones as well.
@qcolombet @nicolasvasilache @MaheshRavishankar @ThomasRaoux
https://github.com/llvm/llvm-project/pull/134675