[Mlir-commits] [mlir] [mlir][linalg][conv] Flatten the channel dimension when vectorizing (PR #71918)
Andrzej Warzyński
llvmlistbot at llvm.org
Thu Nov 16 02:47:13 PST 2023
banach-space wrote:
> @banach-space I really like how your recent change significantly reduced the complexity compared to the previous approach.
@nicolasvasilache Thank you - I was about to write a long reply in which I say "thank you" for your observation (which does indeed make things much simpler), but you have already noticed that :)
> Hopefully the shape_cast swaps nicely with the vector transfers next :)
That would happen as a post-vectorisation canonicalisation, right?
> Should we schedule a meeting to discuss this?
Yes - let me ping you offline :)
Regarding `linalg.generic`s with maps which are not permutations - we need to take a closer look and make sure that the resulting code for convolutions would be equally good. I can take another look at my examples, but I need to prioritise landing this first :) But yes, an important TODO.
As for "strided" loads/stores - super important TODO as well :)
In any case, I have refactored this patch following Nicolas' observation (apologies for over-complicating it so much before). Thank you, that was incredibly helpful 🙏🏻 .
Now, this change is still a "one off" pattern application within the vectoriser. Should this be done elsewhere? We'd need to match whatever `depthwiseConv1dSliceAsMulAcc` is generating post vectorisation, which would turn this into something quite complex.
Btw, I am aware that I still need to refine the tests and to benchmark this.
Finally, to better visualise the current changes:
**Input**
```mlir
func.func @conv_dill_2(%input: memref<3x5x4xf32>,
    %filter: memref<2x4xf32>, %output: memref<3x2x4xf32>) {
  linalg.depthwise_conv_1d_nwc_wc
    {dilations = dense<2> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>}
    ins(%input, %filter : memref<3x5x4xf32>, memref<2x4xf32>)
    outs(%output : memref<3x2x4xf32>)
  return
}
```
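For reference, here is a rough Python sketch of what this op computes, written as a plain loop nest over nested lists in NWC layout. The function name, the `make_inputs` helper and the test data are made up for illustration and are not part of MLIR; only the indexing `w * stride + kw * dilation` reflects the op's semantics as I understand them.

```python
def depthwise_conv_1d_nwc_wc(inp, filt, out, stride=1, dilation=1):
    """Accumulate a depthwise 1-D convolution into `out` (illustrative model only).

    inp:  N x W_in x C nested lists
    filt: KW x C nested lists
    out:  N x W_out x C nested lists, used as the accumulator
    """
    n_size, w_out, c_size = len(out), len(out[0]), len(out[0][0])
    kw_size = len(filt)
    for n in range(n_size):
        for w in range(w_out):
            for c in range(c_size):
                for kw in range(kw_size):
                    # Each channel is convolved independently (depthwise).
                    out[n][w][c] += inp[n][w * stride + kw * dilation][c] * filt[kw][c]
    return out


def make_inputs():
    """Hypothetical data matching the shapes 3x5x4, 2x4 and 3x2x4 used above."""
    inp = [[[float(n * 100 + w * 10 + c) for c in range(4)]
            for w in range(5)] for n in range(3)]
    filt = [[float(kw + 1) * (c + 1) for c in range(4)] for kw in range(2)]
    out = [[[0.0] * 4 for _ in range(2)] for _ in range(3)]
    return inp, filt, out


inp, filt, out = make_inputs()
result = depthwise_conv_1d_nwc_wc(inp, filt, out, stride=1, dilation=2)
```

With a dilation of 2, a filter width of 2 and an output width of 2, only input positions 0 to 3 along W are touched, which is why the vectorised code reads a `vector<3x4x4xf32>` from the `memref<3x5x4xf32>` input.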
**Vectorisation _without_ shape casting**
```mlir
func.func @conv_dill_2(%arg0: memref<3x5x4xf32>, %arg1: memref<2x4xf32>, %arg2: memref<3x2x4xf32>) {
  %c0 = arith.constant 0 : index
  %cst = arith.constant 0.000000e+00 : f32
  %0 = vector.transfer_read %arg0[%c0, %c0, %c0], %cst {in_bounds = [true, true, true]} : memref<3x5x4xf32>, vector<3x4x4xf32>
  %1 = vector.transfer_read %arg1[%c0, %c0], %cst {in_bounds = [true, true]} : memref<2x4xf32>, vector<2x4xf32>
  %2 = vector.transfer_read %arg2[%c0, %c0, %c0], %cst {in_bounds = [true, true, true]} : memref<3x2x4xf32>, vector<3x2x4xf32>
  %3 = vector.extract_strided_slice %0 {offsets = [0, 0, 0], sizes = [3, 2, 4], strides = [1, 1, 1]} : vector<3x4x4xf32> to vector<3x2x4xf32>
  %4 = vector.extract_strided_slice %0 {offsets = [0, 2, 0], sizes = [3, 2, 4], strides = [1, 1, 1]} : vector<3x4x4xf32> to vector<3x2x4xf32>
  %5 = vector.extract %1[0] : vector<4xf32> from vector<2x4xf32>
  %6 = vector.extract %1[1] : vector<4xf32> from vector<2x4xf32>
  %7 = vector.broadcast %5 : vector<4xf32> to vector<3x2x4xf32>
  %8 = vector.fma %3, %7, %2 : vector<3x2x4xf32>
  %9 = vector.broadcast %6 : vector<4xf32> to vector<3x2x4xf32>
  %10 = vector.fma %4, %9, %8 : vector<3x2x4xf32>
  vector.transfer_write %10, %arg2[%c0, %c0, %c0] {in_bounds = [true, true, true]} : vector<3x2x4xf32>, memref<3x2x4xf32>
  return
}
```
**Vectorisation _with_ shape casting**
```mlir
func.func @conv_dill_2(%arg0: memref<3x5x4xf32>, %arg1: memref<2x4xf32>, %arg2: memref<3x2x4xf32>) {
  %c0 = arith.constant 0 : index
  %cst = arith.constant 0.000000e+00 : f32
  %0 = vector.transfer_read %arg0[%c0, %c0, %c0], %cst {in_bounds = [true, true, true]} : memref<3x5x4xf32>, vector<3x4x4xf32>
  %1 = vector.transfer_read %arg1[%c0, %c0], %cst {in_bounds = [true, true]} : memref<2x4xf32>, vector<2x4xf32>
  %2 = vector.transfer_read %arg2[%c0, %c0, %c0], %cst {in_bounds = [true, true, true]} : memref<3x2x4xf32>, vector<3x2x4xf32>
  %3 = vector.extract_strided_slice %0 {offsets = [0, 0, 0], sizes = [3, 2, 4], strides = [1, 1, 1]} : vector<3x4x4xf32> to vector<3x2x4xf32>
  %4 = vector.extract_strided_slice %0 {offsets = [0, 2, 0], sizes = [3, 2, 4], strides = [1, 1, 1]} : vector<3x4x4xf32> to vector<3x2x4xf32>
  %5 = vector.extract %1[0] : vector<4xf32> from vector<2x4xf32>
  %6 = vector.extract %1[1] : vector<4xf32> from vector<2x4xf32>
  %7 = vector.shape_cast %3 : vector<3x2x4xf32> to vector<3x8xf32>
  %8 = vector.shape_cast %2 : vector<3x2x4xf32> to vector<3x8xf32>
  %9 = vector.broadcast %5 : vector<4xf32> to vector<3x2x4xf32>
  %10 = vector.shape_cast %9 : vector<3x2x4xf32> to vector<3x8xf32>
  %11 = vector.fma %7, %10, %8 : vector<3x8xf32>
  %12 = vector.shape_cast %4 : vector<3x2x4xf32> to vector<3x8xf32>
  %13 = vector.broadcast %6 : vector<4xf32> to vector<3x2x4xf32>
  %14 = vector.shape_cast %13 : vector<3x2x4xf32> to vector<3x8xf32>
  %15 = vector.fma %12, %14, %11 : vector<3x8xf32>
  %16 = vector.shape_cast %15 : vector<3x8xf32> to vector<3x2x4xf32>
  vector.transfer_write %16, %arg2[%c0, %c0, %c0] {in_bounds = [true, true, true]} : vector<3x2x4xf32>, memref<3x2x4xf32>
  return
}
```
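To convince oneself that the shape casts do not change the result, here is a small Python model of one FMA step in both forms. It assumes that `vector.shape_cast` from `3x2x4` to `3x8` behaves as a row-major flatten of the two trailing dimensions; all helper names and the test data are hypothetical and exist only for this sketch.

```python
def shape_cast_flatten(v):
    """Model the cast NxWxC -> Nx(W*C) as a row-major flatten of the trailing dims."""
    return [[x for w_row in plane for x in w_row] for plane in v]


def shape_cast_unflatten(v, w, c):
    """Model the inverse cast Nx(W*C) -> NxWxC."""
    return [[row[i * c:(i + 1) * c] for i in range(w)] for row in v]


def broadcast(f, n, w):
    """Model broadcasting a C-element vector to NxWxC."""
    return [[list(f) for _ in range(w)] for _ in range(n)]


def fma_3d(a, b, acc):
    """Element-wise a * b + acc on NxWxC nested lists."""
    return [[[x * y + z for x, y, z in zip(ra, rb, rc)]
             for ra, rb, rc in zip(pa, pb, pc)]
            for pa, pb, pc in zip(a, b, acc)]


def fma_2d(a, b, acc):
    """Element-wise a * b + acc on Nx(W*C) nested lists."""
    return [[x * y + z for x, y, z in zip(ra, rb, rc)]
            for ra, rb, rc in zip(a, b, acc)]


N, W, C = 3, 2, 4
lhs = [[[float(n * 100 + w * 10 + c) for c in range(C)]
        for w in range(W)] for n in range(N)]
acc = [[[1.0] * C for _ in range(W)] for _ in range(N)]
filt = [1.0, 2.0, 3.0, 4.0]  # one filter row, e.g. what %5 holds

rhs = broadcast(filt, N, W)

# Without shape casting: FMA on the 3-D vectors.
unflattened = fma_3d(lhs, rhs, acc)

# With shape casting: flatten every operand, FMA on Nx(W*C), then cast back.
flattened = fma_2d(shape_cast_flatten(lhs),
                   shape_cast_flatten(rhs),
                   shape_cast_flatten(acc))

same = shape_cast_unflatten(flattened, W, C) == unflattened
```

Note that the broadcast filter row, once flattened, is simply the 4 channel values repeated W times, so the FMA operates on 8 contiguous elements per batch entry rather than on 2 groups of 4.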
https://github.com/llvm/llvm-project/pull/71918